From 12719b6e87e738f76d0456dd9a9d7571be58cb68 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 16 Jan 2026 16:34:22 -0800 Subject: [PATCH] MLX - dynamic loading of mlx-c (#13735) * MLX - dynamic loading of mlx-c Create a wrapper layer to indirect the dependency on mlx-c so the main ollama binary does not have a load-time dependency on mlx-c, mlx, and on linux, cuda. Lazy load the library via dlopen so we can adjust the path to ensure the dependencies are found and fail gracefully if not present. * review comments * fix broken tests --- Dockerfile | 18 +- MLX_VERSION | 1 + README.md | 4 +- scripts/build_darwin.sh | 20 +- x/imagegen/cmd/engine/generate.go | 8 +- x/imagegen/cmd/engine/main.go | 6 +- x/imagegen/mlx/compile.go | 2 +- x/imagegen/mlx/doc.go | 6 + x/imagegen/mlx/generate_wrappers.go | 439 ++ x/imagegen/mlx/mlx.c | 5786 +++++++++++++++++ x/imagegen/mlx/mlx.go | 245 +- x/imagegen/mlx/mlx.h | 2337 +++++++ x/imagegen/mlx/mlx_dynamic.c | 144 + x/imagegen/mlx/mlx_dynamic.h | 29 + x/imagegen/mlx/mlx_test.go | 21 + x/imagegen/models/qwen_image/pipeline_test.go | 21 + x/imagegen/models/qwen_image/qwen_image.go | 1 - .../models/qwen_image_edit/rope_test.go | 22 + x/imagegen/nn/nn_test.go | 22 + x/imagegen/runner/runner.go | 6 + x/imagegen/server.go | 9 +- x/ml/backend/mlx/CMakeLists.txt | 6 +- x/ml/backend/mlx/mlx_dynamic.c | 92 + x/ml/backend/mlx/mlx_dynamic.h | 26 + 24 files changed, 9043 insertions(+), 228 deletions(-) create mode 100644 MLX_VERSION create mode 100644 x/imagegen/mlx/doc.go create mode 100644 x/imagegen/mlx/generate_wrappers.go create mode 100644 x/imagegen/mlx/mlx.c create mode 100644 x/imagegen/mlx/mlx.h create mode 100644 x/imagegen/mlx/mlx_dynamic.c create mode 100644 x/imagegen/mlx/mlx_dynamic.h create mode 100644 x/ml/backend/mlx/mlx_dynamic.c create mode 100644 x/ml/backend/mlx/mlx_dynamic.h diff --git a/Dockerfile b/Dockerfile index bddf5c41e..1c0347fd6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -32,7 +32,7 @@ ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH FROM --platform=linux/arm64 almalinux:8 AS base-arm64 # install epel-release for ccache RUN yum install -y yum-utils epel-release \ - && dnf install -y clang ccache \ + && dnf install -y clang ccache git \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo ENV CC=clang CXX=clang++ @@ -149,6 +149,7 @@ COPY CMakeLists.txt CMakePresets.json . COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY x/ml/backend/mlx x/ml/backend/mlx COPY go.mod go.sum . +COPY MLX_VERSION . RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local ENV PATH=/usr/local/go/bin:$PATH RUN go mod download @@ -156,14 +157,6 @@ RUN --mount=type=cache,target=/root/.ccache \ cmake --preset 'MLX CUDA 13' -DBLAS_INCLUDE_DIRS=/usr/include/openblas -DLAPACK_INCLUDE_DIRS=/usr/include/openblas \ && cmake --build --parallel ${PARALLEL} --preset 'MLX CUDA 13' \ && cmake --install build --component MLX --strip --parallel ${PARALLEL} -COPY . . -ARG GOFLAGS="'-ldflags=-w -s'" -ENV CGO_ENABLED=1 -ARG CGO_CFLAGS -ARG CGO_CXXFLAGS -RUN mkdir -p dist/bin -RUN --mount=type=cache,target=/root/.cache/go-build \ - go build -tags mlx -trimpath -buildmode=pie -o dist/bin/ollama-mlx . FROM base AS build WORKDIR /go/src/github.com/ollama/ollama @@ -172,12 +165,14 @@ RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux- ENV PATH=/usr/local/go/bin:$PATH RUN go mod download COPY . . +# Clone mlx-c headers for CGO (version from MLX_VERSION file) +RUN git clone --depth 1 --branch "$(cat MLX_VERSION)" https://github.com/ml-explore/mlx-c.git build/_deps/mlx-c-src ARG GOFLAGS="'-ldflags=-w -s'" ENV CGO_ENABLED=1 -ARG CGO_CFLAGS +ENV CGO_CFLAGS="-I/go/src/github.com/ollama/ollama/build/_deps/mlx-c-src" ARG CGO_CXXFLAGS RUN --mount=type=cache,target=/root/.cache/go-build \ - go build -trimpath -buildmode=pie -o /bin/ollama . + go build -tags mlx -trimpath -buildmode=pie -o /bin/ollama . FROM --platform=linux/amd64 scratch AS amd64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ @@ -185,7 +180,6 @@ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ COPY --from=vulkan dist/lib/ollama /lib/ollama/ COPY --from=mlx /go/src/github.com/ollama/ollama/dist/lib/ollama /lib/ollama/ -COPY --from=mlx /go/src/github.com/ollama/ollama/dist/bin/ /bin/ FROM --platform=linux/arm64 scratch AS arm64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ diff --git a/MLX_VERSION b/MLX_VERSION new file mode 100644 index 000000000..5aff472dd --- /dev/null +++ b/MLX_VERSION @@ -0,0 +1 @@ +v0.4.1 diff --git a/README.md b/README.md index bda6b4c34..f9cee2af3 100644 --- a/README.md +++ b/README.md @@ -270,10 +270,10 @@ cmake --build --preset MLX --parallel cmake --install build --component MLX ``` -Next, build the `ollama-mlx` binary, which is a separate build of the Ollama runtime with MLX support enabled (needs to be in the same directory as `ollama`): +When building with the `-tags mlx` flag, the main `ollama` binary includes MLX support for experimental features like image generation: ```shell -go build -tags mlx -o ollama-mlx . +go build -tags mlx . ``` Finally, start the server: diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index ed5f97aa5..3560520ff 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -60,7 +60,7 @@ _build_darwin() { cmake --install $BUILD_DIR --component MLX # Override CGO flags to point to the amd64 build directory MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" - MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Accelerate -mmacosx-version-min=14.0" + MLX_CGO_LDFLAGS="-ldl -lc++ -framework Accelerate -mmacosx-version-min=14.0" else BUILD_DIR=build cmake --preset MLX \ @@ -71,10 +71,12 @@ _build_darwin() { cmake --install $BUILD_DIR --component MLX # Use default CGO flags from mlx.go for arm64 MLX_CGO_CFLAGS="-O3 -I$(pwd)/$BUILD_DIR/_deps/mlx-c-src -mmacosx-version-min=14.0" - MLX_CGO_LDFLAGS="-L$(pwd)/$BUILD_DIR/lib/ollama -lmlxc -lmlx -Wl,-rpath,@executable_path -lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0" + MLX_CGO_LDFLAGS="-lc++ -framework Metal -framework Foundation -framework Accelerate -mmacosx-version-min=14.0" fi - GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX/ollama-mlx . - GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 go build -o $INSTALL_PREFIX . + GOOS=darwin GOARCH=$ARCH CGO_ENABLED=1 CGO_CFLAGS="$MLX_CGO_CFLAGS" CGO_LDFLAGS="$MLX_CGO_LDFLAGS" go build -tags mlx -o $INSTALL_PREFIX . + # Copy MLX libraries to same directory as executable for dlopen + cp $INSTALL_PREFIX/lib/ollama/libmlxc.dylib $INSTALL_PREFIX/ + cp $INSTALL_PREFIX/lib/ollama/libmlx.dylib $INSTALL_PREFIX/ done } @@ -82,12 +84,10 @@ _sign_darwin() { status "Creating universal binary..." mkdir -p dist/darwin lipo -create -output dist/darwin/ollama dist/darwin-*/ollama - lipo -create -output dist/darwin/ollama-mlx dist/darwin-*/ollama-mlx chmod +x dist/darwin/ollama - chmod +x dist/darwin/ollama-mlx if [ -n "$APPLE_IDENTITY" ]; then - for F in dist/darwin/ollama dist/darwin-*/lib/ollama/* dist/darwin/ollama-mlx; do + for F in dist/darwin/ollama dist/darwin-*/lib/ollama/*; do codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime $F done @@ -154,7 +154,6 @@ _build_macapp() { mkdir -p dist/Ollama.app/Contents/Resources if [ -d dist/darwin-amd64 ]; then lipo -create -output dist/Ollama.app/Contents/Resources/ollama dist/darwin-amd64/ollama dist/darwin-arm64/ollama - lipo -create -output dist/Ollama.app/Contents/Resources/ollama-mlx dist/darwin-amd64/ollama-mlx dist/darwin-arm64/ollama-mlx for F in dist/darwin-amd64/lib/ollama/*mlx*.dylib ; do lipo -create -output dist/darwin/$(basename $F) $F dist/darwin-arm64/lib/ollama/$(basename $F) done @@ -166,13 +165,12 @@ _build_macapp() { cp -a dist/darwin/ollama dist/Ollama.app/Contents/Resources/ollama cp dist/darwin/*.so dist/darwin/*.dylib dist/Ollama.app/Contents/Resources/ fi - cp -a dist/darwin/ollama-mlx dist/Ollama.app/Contents/Resources/ollama-mlx chmod a+x dist/Ollama.app/Contents/Resources/ollama # Sign if [ -n "$APPLE_IDENTITY" ]; then codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/Ollama.app/Contents/Resources/ollama - for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib dist/Ollama.app/Contents/Resources/ollama-mlx ; do + for lib in dist/Ollama.app/Contents/Resources/*.so dist/Ollama.app/Contents/Resources/*.dylib dist/Ollama.app/Contents/Resources/*.metallib ; do codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime ${lib} done codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier com.electron.ollama --deep --options=runtime dist/Ollama.app @@ -180,7 +178,7 @@ _build_macapp() { rm -f dist/Ollama-darwin.zip ditto -c -k --norsrc --keepParent dist/Ollama.app dist/Ollama-darwin.zip - (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama ollama-mlx *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz + (cd dist/Ollama.app/Contents/Resources/; tar -cf - ollama *.so *.dylib *.metallib 2>/dev/null) | gzip -9vc > dist/ollama-darwin.tgz # Notarize and Staple if [ -n "$APPLE_IDENTITY" ]; then diff --git a/x/imagegen/cmd/engine/generate.go b/x/imagegen/cmd/engine/generate.go index 506a48c54..51118afc1 100644 --- a/x/imagegen/cmd/engine/generate.go +++ b/x/imagegen/cmd/engine/generate.go @@ -65,12 +65,12 @@ func (s *utf8Streamer) Flush() string { return result } -func init() { - generationStream = mlx.NewStream() -} - // withStream runs fn with the generation stream as default func withStream(fn func()) { + // Lazy initialization of generationStream + if generationStream == nil { + generationStream = mlx.NewStream() + } orig := mlx.GetDefaultStream() mlx.SetDefaultStream(generationStream) fn() diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go index bd1f871f6..02da278bc 100644 --- a/x/imagegen/cmd/engine/main.go +++ b/x/imagegen/cmd/engine/main.go @@ -12,7 +12,6 @@ import ( "path/filepath" "runtime/pprof" - "github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/models/gemma3" "github.com/ollama/ollama/x/imagegen/models/gpt_oss" @@ -79,6 +78,11 @@ func main() { return } + // Check if MLX initialized successfully + if !mlx.IsMLXAvailable() { + log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError()) + } + // CPU profiling if *cpuProfile != "" { f, err := os.Create(*cpuProfile) diff --git a/x/imagegen/mlx/compile.go b/x/imagegen/mlx/compile.go index 36de65c5f..0dd2dd02a 100644 --- a/x/imagegen/mlx/compile.go +++ b/x/imagegen/mlx/compile.go @@ -3,7 +3,7 @@ package mlx /* -#include "mlx/c/mlx.h" +#include "mlx.h" #include // Forward declaration for Go callback diff --git a/x/imagegen/mlx/doc.go b/x/imagegen/mlx/doc.go new file mode 100644 index 000000000..ced1802b0 --- /dev/null +++ b/x/imagegen/mlx/doc.go @@ -0,0 +1,6 @@ +//go:build mlx + +// Package mlx provides Go bindings for the MLX-C library with dynamic loading support. +// +//go:generate go run generate_wrappers.go ../../../build/_deps/mlx-c-src/mlx/c mlx.h mlx.c +package mlx diff --git a/x/imagegen/mlx/generate_wrappers.go b/x/imagegen/mlx/generate_wrappers.go new file mode 100644 index 000000000..a55def02b --- /dev/null +++ b/x/imagegen/mlx/generate_wrappers.go @@ -0,0 +1,439 @@ +//go:build ignore + +// This tool generates MLX-C dynamic loading wrappers. +// Usage: go run generate_wrappers.go [output-impl] +package main + +import ( + "bytes" + "flag" + "fmt" + "io/fs" + "os" + "path/filepath" + "regexp" + "strings" +) + +type Function struct { + Name string + ReturnType string + Params string + ParamNames []string + NeedsARM64Guard bool +} + +func findHeaders(directory string) ([]string, error) { + var headers []string + err := filepath.WalkDir(directory, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() && strings.HasSuffix(path, ".h") { + headers = append(headers, path) + } + return nil + }) + return headers, err +} + +func cleanContent(content string) string { + // Remove single-line comments + re := regexp.MustCompile(`//.*?\n`) + content = re.ReplaceAllString(content, "\n") + + // Remove multi-line comments + re = regexp.MustCompile(`/\*.*?\*/`) + content = re.ReplaceAllString(content, "") + + // Remove preprocessor directives (lines starting with #) - use multiline mode + re = regexp.MustCompile(`(?m)^\s*#.*?$`) + content = re.ReplaceAllString(content, "") + + // Remove extern "C" { and } blocks more conservatively + // Only remove the extern "C" { line, not the content inside + re = regexp.MustCompile(`extern\s+"C"\s*\{\s*?\n`) + content = re.ReplaceAllString(content, "\n") + // Remove standalone closing braces that are not part of function declarations + re = regexp.MustCompile(`\n\s*\}\s*\n`) + content = re.ReplaceAllString(content, "\n") + + // Collapse whitespace and newlines + re = regexp.MustCompile(`\s+`) + content = re.ReplaceAllString(content, " ") + + return content +} + +func extractParamNames(params string) []string { + if params == "" || strings.TrimSpace(params) == "void" { + return []string{} + } + + var names []string + + // Split by comma, but respect parentheses (for function pointers) + parts := splitParams(params) + + // Remove array brackets + arrayBrackets := regexp.MustCompile(`\[.*?\]`) + + // Function pointer pattern + funcPtrPattern := regexp.MustCompile(`\(\s*\*\s*(\w+)\s*\)`) + + // Type keywords to skip + typeKeywords := map[string]bool{ + "const": true, + "struct": true, + "unsigned": true, + "signed": true, + "long": true, + "short": true, + "int": true, + "char": true, + "float": true, + "double": true, + "void": true, + "size_t": true, + "uint8_t": true, + "uint16_t": true, + "uint32_t": true, + "uint64_t": true, + "int8_t": true, + "int16_t": true, + "int32_t": true, + "int64_t": true, + "intptr_t": true, + "uintptr_t": true, + } + + for _, part := range parts { + if part == "" { + continue + } + + // Remove array brackets + part = arrayBrackets.ReplaceAllString(part, "") + + // For function pointers like "void (*callback)(int)" + if matches := funcPtrPattern.FindStringSubmatch(part); len(matches) > 1 { + names = append(names, matches[1]) + continue + } + + // Regular parameter: last identifier + tokens := regexp.MustCompile(`\w+`).FindAllString(part, -1) + if len(tokens) > 0 { + // The last token is usually the parameter name + // Skip type keywords + for i := len(tokens) - 1; i >= 0; i-- { + if !typeKeywords[tokens[i]] { + names = append(names, tokens[i]) + break + } + } + } + } + + return names +} + +func splitParams(params string) []string { + var parts []string + var current bytes.Buffer + depth := 0 + + for _, char := range params + "," { + switch char { + case '(': + depth++ + current.WriteRune(char) + case ')': + depth-- + current.WriteRune(char) + case ',': + if depth == 0 { + parts = append(parts, strings.TrimSpace(current.String())) + current.Reset() + } else { + current.WriteRune(char) + } + default: + current.WriteRune(char) + } + } + + return parts +} + +func parseFunctions(content string) []Function { + var functions []Function + + // Match function declarations: return_type function_name(params); + // Matches both mlx_* and _mlx_* functions + pattern := regexp.MustCompile(`\b((?:const\s+)?(?:struct\s+)?[\w\s]+?[\*\s]*)\s+(_?mlx_\w+)\s*\(([^)]*(?:\([^)]*\)[^)]*)*)\)\s*;`) + + matches := pattern.FindAllStringSubmatch(content, -1) + for _, match := range matches { + returnType := strings.TrimSpace(match[1]) + funcName := strings.TrimSpace(match[2]) + params := strings.TrimSpace(match[3]) + + // Skip if this looks like a variable declaration + if params == "" || strings.Contains(params, "{") { + continue + } + + // Clean up return type + returnType = strings.Join(strings.Fields(returnType), " ") + + // Extract parameter names + paramNames := extractParamNames(params) + + // Check if ARM64 guard is needed + needsGuard := needsARM64Guard(funcName, returnType, params) + + functions = append(functions, Function{ + Name: funcName, + ReturnType: returnType, + Params: params, + ParamNames: paramNames, + NeedsARM64Guard: needsGuard, + }) + } + + return functions +} + +func needsARM64Guard(name, retType, params string) bool { + return strings.Contains(name, "float16") || + strings.Contains(name, "bfloat16") || + strings.Contains(retType, "float16_t") || + strings.Contains(retType, "bfloat16_t") || + strings.Contains(params, "float16_t") || + strings.Contains(params, "bfloat16_t") +} + +func generateWrapperFiles(functions []Function, headerPath, implPath string) error { + // Generate header file + var headerBuf bytes.Buffer + + headerBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n") + headerBuf.WriteString("// This file provides wrapper declarations for MLX-C functions that use dlopen/dlsym\n") + headerBuf.WriteString("//\n") + headerBuf.WriteString("// Strategy: Include MLX-C headers for type definitions, then provide wrapper\n") + headerBuf.WriteString("// functions that shadow the originals, allowing Go code to call them directly (e.g., C.mlx_add).\n") + headerBuf.WriteString("// Function pointers are defined in mlx.c (single compilation unit).\n\n") + headerBuf.WriteString("#ifndef MLX_WRAPPERS_H\n") + headerBuf.WriteString("#define MLX_WRAPPERS_H\n\n") + + headerBuf.WriteString("// Include MLX headers for type definitions and original declarations\n") + headerBuf.WriteString("#include \"mlx/c/mlx.h\"\n") + headerBuf.WriteString("#include \"mlx_dynamic.h\"\n") + headerBuf.WriteString("#include \n\n") + + // Undef all MLX functions to avoid conflicts + headerBuf.WriteString("// Undefine any existing MLX function macros\n") + for _, fn := range functions { + headerBuf.WriteString(fmt.Sprintf("#undef %s\n", fn.Name)) + } + headerBuf.WriteString("\n") + + // Function pointer extern declarations + headerBuf.WriteString("// Function pointer declarations (defined in mlx.c, loaded via dlsym)\n") + for _, fn := range functions { + if fn.NeedsARM64Guard { + headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n") + } + headerBuf.WriteString(fmt.Sprintf("extern %s (*%s_ptr)(%s);\n", fn.ReturnType, fn.Name, fn.Params)) + if fn.NeedsARM64Guard { + headerBuf.WriteString("#endif\n") + } + } + headerBuf.WriteString("\n") + + // Initialization function declaration + headerBuf.WriteString("// Initialize all function pointers via dlsym (defined in mlx.c)\n") + headerBuf.WriteString("int mlx_load_functions(void* handle);\n\n") + + // Wrapper function declarations + headerBuf.WriteString("// Wrapper function declarations that call through function pointers\n") + headerBuf.WriteString("// Go code calls these directly as C.mlx_* (no #define redirection needed)\n") + for _, fn := range functions { + if fn.NeedsARM64Guard { + headerBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n") + } + headerBuf.WriteString(fmt.Sprintf("%s %s(%s);\n", fn.ReturnType, fn.Name, fn.Params)) + if fn.NeedsARM64Guard { + headerBuf.WriteString("#endif\n") + } + headerBuf.WriteString("\n") + } + + headerBuf.WriteString("#endif // MLX_WRAPPERS_H\n") + + // Write header file + if err := os.WriteFile(headerPath, headerBuf.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write header file: %w", err) + } + + // Generate implementation file + var implBuf bytes.Buffer + + implBuf.WriteString("// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT\n") + implBuf.WriteString("// This file contains the function pointer definitions and initialization\n") + implBuf.WriteString("// All function pointers are in a single compilation unit to avoid duplication\n\n") + + implBuf.WriteString("#include \"mlx/c/mlx.h\"\n") + implBuf.WriteString("#include \"mlx_dynamic.h\"\n") + implBuf.WriteString("#include \n") + implBuf.WriteString("#include \n\n") + + // Function pointer definitions + implBuf.WriteString("// Function pointer definitions\n") + for _, fn := range functions { + if fn.NeedsARM64Guard { + implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n") + } + implBuf.WriteString(fmt.Sprintf("%s (*%s_ptr)(%s) = NULL;\n", fn.ReturnType, fn.Name, fn.Params)) + if fn.NeedsARM64Guard { + implBuf.WriteString("#endif\n") + } + } + implBuf.WriteString("\n") + + // Initialization function + implBuf.WriteString("// Initialize all function pointers via dlsym\n") + implBuf.WriteString("int mlx_load_functions(void* handle) {\n") + implBuf.WriteString(" if (handle == NULL) {\n") + implBuf.WriteString(" fprintf(stderr, \"MLX: Invalid library handle\\n\");\n") + implBuf.WriteString(" return -1;\n") + implBuf.WriteString(" }\n\n") + + for _, fn := range functions { + if fn.NeedsARM64Guard { + implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n") + } + implBuf.WriteString(fmt.Sprintf(" %s_ptr = dlsym(handle, \"%s\");\n", fn.Name, fn.Name)) + implBuf.WriteString(fmt.Sprintf(" if (%s_ptr == NULL) {\n", fn.Name)) + implBuf.WriteString(fmt.Sprintf(" fprintf(stderr, \"MLX: Failed to load symbol: %s\\n\");\n", fn.Name)) + implBuf.WriteString(" return -1;\n") + implBuf.WriteString(" }\n") + if fn.NeedsARM64Guard { + implBuf.WriteString("#endif\n") + } + } + + implBuf.WriteString(" return 0;\n") + implBuf.WriteString("}\n\n") + + // Wrapper function implementations + implBuf.WriteString("// Wrapper function implementations that call through function pointers\n") + for _, fn := range functions { + if fn.NeedsARM64Guard { + implBuf.WriteString("#if defined(__aarch64__) || defined(_M_ARM64)\n") + } + implBuf.WriteString(fmt.Sprintf("%s %s(%s) {\n", fn.ReturnType, fn.Name, fn.Params)) + + // Call through function pointer + if fn.ReturnType != "void" { + implBuf.WriteString(fmt.Sprintf(" return %s_ptr(", fn.Name)) + } else { + implBuf.WriteString(fmt.Sprintf(" %s_ptr(", fn.Name)) + } + + // Pass parameters + implBuf.WriteString(strings.Join(fn.ParamNames, ", ")) + implBuf.WriteString(");\n") + implBuf.WriteString("}\n") + if fn.NeedsARM64Guard { + implBuf.WriteString("#endif\n") + } + implBuf.WriteString("\n") + } + + // Write implementation file + if err := os.WriteFile(implPath, implBuf.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write implementation file: %w", err) + } + + return nil +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "Usage: go run generate_wrappers.go [output-impl]\n") + fmt.Fprintf(flag.CommandLine.Output(), "Generate MLX-C dynamic loading wrappers.\n\n") + flag.PrintDefaults() + } + flag.Parse() + + args := flag.Args() + if len(args) < 2 { + fmt.Fprintf(flag.CommandLine.Output(), "ERROR: Missing required arguments\n\n") + flag.Usage() + os.Exit(1) + } + + headerDir := args[0] + outputHeader := args[1] + // Default implementation file is same name with .c extension + outputImpl := outputHeader + if len(args) > 2 { + outputImpl = args[2] + } else if strings.HasSuffix(outputHeader, ".h") { + outputImpl = outputHeader[:len(outputHeader)-2] + ".c" + } + + // Check if header directory exists + if _, err := os.Stat(headerDir); os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "ERROR: MLX-C headers directory not found at: %s\n\n", headerDir) + fmt.Fprintf(os.Stderr, "Please run CMake first to download MLX-C dependencies:\n") + fmt.Fprintf(os.Stderr, " cmake -B build\n\n") + fmt.Fprintf(os.Stderr, "The CMake build will download and extract MLX-C headers needed for wrapper generation.\n") + os.Exit(1) + } + + fmt.Fprintf(os.Stderr, "Parsing MLX-C headers from: %s\n", headerDir) + + // Find all headers + headers, err := findHeaders(headerDir) + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: Failed to find header files: %v\n", err) + os.Exit(1) + } + fmt.Fprintf(os.Stderr, "Found %d header files\n", len(headers)) + + // Parse all headers + var allFunctions []Function + seen := make(map[string]bool) + + for _, header := range headers { + content, err := os.ReadFile(header) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading %s: %v\n", header, err) + continue + } + + cleaned := cleanContent(string(content)) + functions := parseFunctions(cleaned) + + // Deduplicate + for _, fn := range functions { + if !seen[fn.Name] { + seen[fn.Name] = true + allFunctions = append(allFunctions, fn) + } + } + } + + fmt.Fprintf(os.Stderr, "Found %d unique function declarations\n", len(allFunctions)) + + // Generate wrapper files + if err := generateWrapperFiles(allFunctions, outputHeader, outputImpl); err != nil { + fmt.Fprintf(os.Stderr, "ERROR: Failed to generate wrapper files: %v\n", err) + os.Exit(1) + } + + fmt.Fprintf(os.Stderr, "Generated %s and %s successfully\n", outputHeader, outputImpl) +} diff --git a/x/imagegen/mlx/mlx.c b/x/imagegen/mlx/mlx.c new file mode 100644 index 000000000..564076f30 --- /dev/null +++ b/x/imagegen/mlx/mlx.c @@ -0,0 +1,5786 @@ +// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT +// This file contains the function pointer definitions and initialization +// All function pointers are in a single compilation unit to avoid duplication + +#include "mlx/c/mlx.h" +#include "mlx_dynamic.h" +#include +#include + +// Function pointer definitions +size_t (*mlx_dtype_size_ptr)(mlx_dtype dtype) = NULL; +int (*mlx_array_tostring_ptr)(mlx_string* str, const mlx_array arr) = NULL; +mlx_array (*mlx_array_new_ptr)(void) = NULL; +int (*mlx_array_free_ptr)(mlx_array arr) = NULL; +mlx_array (*mlx_array_new_bool_ptr)(bool val) = NULL; +mlx_array (*mlx_array_new_int_ptr)(int val) = NULL; +mlx_array (*mlx_array_new_float32_ptr)(float val) = NULL; +mlx_array (*mlx_array_new_float_ptr)(float val) = NULL; +mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL; +mlx_array (*mlx_array_new_double_ptr)(double val) = NULL; +mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL; +mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL; +int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL; +int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL; +int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL; +int (*mlx_array_set_float32_ptr)(mlx_array* arr, float val) = NULL; +int (*mlx_array_set_float_ptr)(mlx_array* arr, float val) = NULL; +int (*mlx_array_set_float64_ptr)(mlx_array* arr, double val) = NULL; +int (*mlx_array_set_double_ptr)(mlx_array* arr, double val) = NULL; +int (*mlx_array_set_complex_ptr)(mlx_array* arr, float real_val, float imag_val) = NULL; +int (*mlx_array_set_data_ptr)(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL; +size_t (*mlx_array_itemsize_ptr)(const mlx_array arr) = NULL; +size_t (*mlx_array_size_ptr)(const mlx_array arr) = NULL; +size_t (*mlx_array_nbytes_ptr)(const mlx_array arr) = NULL; +size_t (*mlx_array_ndim_ptr)(const mlx_array arr) = NULL; +const int* (*mlx_array_shape_ptr)(const mlx_array arr) = NULL; +const size_t* (*mlx_array_strides_ptr)(const mlx_array arr) = NULL; +int (*mlx_array_dim_ptr)(const mlx_array arr, int dim) = NULL; +mlx_dtype (*mlx_array_dtype_ptr)(const mlx_array arr) = NULL; +int (*mlx_array_eval_ptr)(mlx_array arr) = NULL; +int (*mlx_array_item_bool_ptr)(bool* res, const mlx_array arr) = NULL; +int (*mlx_array_item_uint8_ptr)(uint8_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_uint16_ptr)(uint16_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_uint32_ptr)(uint32_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_uint64_ptr)(uint64_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_int8_ptr)(int8_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_int16_ptr)(int16_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL; +int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL; +int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL; +int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL; +#if defined(__aarch64__) || defined(_M_ARM64) +int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL; +#endif +#if defined(__aarch64__) || defined(_M_ARM64) +int (*mlx_array_item_bfloat16_ptr)(bfloat16_t* res, const mlx_array arr) = NULL; +#endif +const bool* (*mlx_array_data_bool_ptr)(const mlx_array arr) = NULL; +const uint8_t* (*mlx_array_data_uint8_ptr)(const mlx_array arr) = NULL; +const uint16_t* (*mlx_array_data_uint16_ptr)(const mlx_array arr) = NULL; +const uint32_t* (*mlx_array_data_uint32_ptr)(const mlx_array arr) = NULL; +const uint64_t* (*mlx_array_data_uint64_ptr)(const mlx_array arr) = NULL; +const int8_t* (*mlx_array_data_int8_ptr)(const mlx_array arr) = NULL; +const int16_t* (*mlx_array_data_int16_ptr)(const mlx_array arr) = NULL; +const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL; +const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL; +const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL; +const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL; +const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL; +#if defined(__aarch64__) || defined(_M_ARM64) +const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL; +#endif +#if defined(__aarch64__) || defined(_M_ARM64) +const bfloat16_t* (*mlx_array_data_bfloat16_ptr)(const mlx_array arr) = NULL; +#endif +int (*_mlx_array_is_available_ptr)(bool* res, const mlx_array arr) = NULL; +int (*_mlx_array_wait_ptr)(const mlx_array arr) = NULL; +int (*_mlx_array_is_contiguous_ptr)(bool* res, const mlx_array arr) = NULL; +int (*_mlx_array_is_row_contiguous_ptr)(bool* res, const mlx_array arr) = NULL; +int (*_mlx_array_is_col_contiguous_ptr)(bool* res, const mlx_array arr) = NULL; +mlx_closure (*mlx_closure_new_ptr)(void) = NULL; +int (*mlx_closure_free_ptr)(mlx_closure cls) = NULL; +mlx_closure (*mlx_closure_new_func_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array)) = NULL; +mlx_closure (*mlx_closure_new_func_payload_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL; +int (*mlx_closure_set_ptr)(mlx_closure* cls, const mlx_closure src) = NULL; +int (*mlx_closure_apply_ptr)(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) = NULL; +mlx_closure (*mlx_closure_new_unary_ptr)(int (*fun)(mlx_array*, const mlx_array)) = NULL; +mlx_closure_kwargs (*mlx_closure_kwargs_new_ptr)(void) = NULL; +int (*mlx_closure_kwargs_free_ptr)(mlx_closure_kwargs cls) = NULL; +mlx_closure_kwargs (*mlx_closure_kwargs_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) = NULL; +mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)) = NULL; +int (*mlx_closure_kwargs_set_ptr)(mlx_closure_kwargs* cls, const mlx_closure_kwargs src) = NULL; +int (*mlx_closure_kwargs_apply_ptr)(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1) = NULL; +mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_ptr)(void) = NULL; +int (*mlx_closure_value_and_grad_free_ptr)(mlx_closure_value_and_grad cls) = NULL; +mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_ptr)(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) = NULL; +mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL; +int (*mlx_closure_value_and_grad_set_ptr)(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) = NULL; +int (*mlx_closure_value_and_grad_apply_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input) = NULL; +mlx_closure_custom (*mlx_closure_custom_new_ptr)(void) = NULL; +int (*mlx_closure_custom_free_ptr)(mlx_closure_custom cls) = NULL; +mlx_closure_custom (*mlx_closure_custom_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)) = NULL; +mlx_closure_custom (*mlx_closure_custom_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL; +int (*mlx_closure_custom_set_ptr)(mlx_closure_custom* cls, const mlx_closure_custom src) = NULL; +int (*mlx_closure_custom_apply_ptr)(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2) = NULL; +mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_ptr)(void) = NULL; +int (*mlx_closure_custom_jvp_free_ptr)(mlx_closure_custom_jvp cls) = NULL; +mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)) = NULL; +mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) = NULL; +int (*mlx_closure_custom_jvp_set_ptr)(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) = NULL; +int (*mlx_closure_custom_jvp_apply_ptr)(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num) = NULL; +mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_ptr)(void) = NULL; +int (*mlx_closure_custom_vmap_free_ptr)(mlx_closure_custom_vmap cls) = NULL; +mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)) = NULL; +mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) = NULL; +int (*mlx_closure_custom_vmap_set_ptr)(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) = NULL; +int (*mlx_closure_custom_vmap_apply_ptr)(mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num) = NULL; +int (*mlx_compile_ptr)(mlx_closure* res, const mlx_closure fun, bool shapeless) = NULL; +int (*mlx_detail_compile_ptr)(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num) = NULL; +int (*mlx_detail_compile_clear_cache_ptr)(void) = NULL; +int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL; +int (*mlx_disable_compile_ptr)(void) = NULL; +int (*mlx_enable_compile_ptr)(void) = NULL; +int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL; +mlx_device (*mlx_device_new_ptr)(void) = NULL; +mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL; +int (*mlx_device_free_ptr)(mlx_device dev) = NULL; +int (*mlx_device_set_ptr)(mlx_device* dev, const mlx_device src) = NULL; +int (*mlx_device_tostring_ptr)(mlx_string* str, mlx_device dev) = NULL; +bool (*mlx_device_equal_ptr)(mlx_device lhs, mlx_device rhs) = NULL; +int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL; +int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL; +int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL; +int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL; +int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL; +int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; +int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; +int (*mlx_distributed_all_sum_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; +int (*mlx_distributed_recv_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s) = NULL; +int (*mlx_distributed_recv_like_ptr)(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s) = NULL; +int (*mlx_distributed_send_ptr)(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s) = NULL; +int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL; +int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL; +int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL; +mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL; +bool (*mlx_distributed_is_available_ptr)(void) = NULL; +mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL; +void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL; +void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL; +int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL; +int (*mlx_export_function_kwargs_ptr)(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) = NULL; +mlx_function_exporter (*mlx_function_exporter_new_ptr)(const char* file, const mlx_closure fun, bool shapeless) = NULL; +int (*mlx_function_exporter_free_ptr)(mlx_function_exporter xfunc) = NULL; +int (*mlx_function_exporter_apply_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args) = NULL; +int (*mlx_function_exporter_apply_kwargs_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) = NULL; +mlx_imported_function (*mlx_imported_function_new_ptr)(const char* file) = NULL; +int (*mlx_imported_function_free_ptr)(mlx_imported_function xfunc) = NULL; +int (*mlx_imported_function_apply_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) = NULL; +int (*mlx_imported_function_apply_kwargs_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) = NULL; +mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_ptr)(void) = NULL; +void (*mlx_fast_cuda_kernel_config_free_ptr)(mlx_fast_cuda_kernel_config cls) = NULL; +int (*mlx_fast_cuda_kernel_config_add_output_arg_ptr)(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) = NULL; +int (*mlx_fast_cuda_kernel_config_set_grid_ptr)(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) = NULL; +int (*mlx_fast_cuda_kernel_config_set_thread_group_ptr)(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) = NULL; +int (*mlx_fast_cuda_kernel_config_set_init_value_ptr)(mlx_fast_cuda_kernel_config cls, float value) = NULL; +int (*mlx_fast_cuda_kernel_config_set_verbose_ptr)(mlx_fast_cuda_kernel_config cls, bool verbose) = NULL; +int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) = NULL; +int (*mlx_fast_cuda_kernel_config_add_template_arg_int_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, int value) = NULL; +int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, bool value) = NULL; +mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) = NULL; +void (*mlx_fast_cuda_kernel_free_ptr)(mlx_fast_cuda_kernel cls) = NULL; +int (*mlx_fast_cuda_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream) = NULL; +int (*mlx_fast_layer_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s) = NULL; +mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_ptr)(void) = NULL; +void (*mlx_fast_metal_kernel_config_free_ptr)(mlx_fast_metal_kernel_config cls) = NULL; +int (*mlx_fast_metal_kernel_config_add_output_arg_ptr)(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) = NULL; +int (*mlx_fast_metal_kernel_config_set_grid_ptr)(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) = NULL; +int (*mlx_fast_metal_kernel_config_set_thread_group_ptr)(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) = NULL; +int (*mlx_fast_metal_kernel_config_set_init_value_ptr)(mlx_fast_metal_kernel_config cls, float value) = NULL; +int (*mlx_fast_metal_kernel_config_set_verbose_ptr)(mlx_fast_metal_kernel_config cls, bool verbose) = NULL; +int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) = NULL; +int (*mlx_fast_metal_kernel_config_add_template_arg_int_ptr)(mlx_fast_metal_kernel_config cls, const char* name, int value) = NULL; +int (*mlx_fast_metal_kernel_config_add_template_arg_bool_ptr)(mlx_fast_metal_kernel_config cls, const char* name, bool value) = NULL; +mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) = NULL; +void (*mlx_fast_metal_kernel_free_ptr)(mlx_fast_metal_kernel cls) = NULL; +int (*mlx_fast_metal_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) = NULL; +int (*mlx_fast_rms_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s) = NULL; +int (*mlx_fast_rope_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s) = NULL; +int (*mlx_fast_rope_dynamic_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s) = NULL; +int (*mlx_fast_scaled_dot_product_attention_ptr)(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s) = NULL; +int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; +int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_fftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; +int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_ifftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; +int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL; +int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_load_reader_ptr)(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) = NULL; +int (*mlx_load_ptr)(mlx_array* res, const char* file, const mlx_stream s) = NULL; +int (*mlx_load_safetensors_reader_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) = NULL; +int (*mlx_load_safetensors_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) = NULL; +int (*mlx_save_writer_ptr)(mlx_io_writer out_stream, const mlx_array a) = NULL; +int (*mlx_save_ptr)(const char* file, const mlx_array a) = NULL; +int (*mlx_save_safetensors_writer_ptr)(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL; +int (*mlx_save_safetensors_ptr)(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL; +mlx_io_reader (*mlx_io_reader_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL; +int (*mlx_io_reader_descriptor_ptr)(void** desc_, mlx_io_reader io) = NULL; +int (*mlx_io_reader_tostring_ptr)(mlx_string* str_, mlx_io_reader io) = NULL; +int (*mlx_io_reader_free_ptr)(mlx_io_reader io) = NULL; +mlx_io_writer (*mlx_io_writer_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL; +int (*mlx_io_writer_descriptor_ptr)(void** desc_, mlx_io_writer io) = NULL; +int (*mlx_io_writer_tostring_ptr)(mlx_string* str_, mlx_io_writer io) = NULL; +int (*mlx_io_writer_free_ptr)(mlx_io_writer io) = NULL; +int (*mlx_linalg_cholesky_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL; +int (*mlx_linalg_cholesky_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL; +int (*mlx_linalg_cross_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL; +int (*mlx_linalg_eig_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_eigh_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) = NULL; +int (*mlx_linalg_eigvals_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_eigvalsh_ptr)(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) = NULL; +int (*mlx_linalg_inv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_lu_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_lu_factor_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_norm_ptr)(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_linalg_norm_matrix_ptr)(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_linalg_norm_l2_ptr)(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_linalg_pinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_qr_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_linalg_solve_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_linalg_solve_triangular_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) = NULL; +int (*mlx_linalg_svd_ptr)(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) = NULL; +int (*mlx_linalg_tri_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL; +mlx_map_string_to_array (*mlx_map_string_to_array_new_ptr)(void) = NULL; +int (*mlx_map_string_to_array_set_ptr)(mlx_map_string_to_array* map, const mlx_map_string_to_array src) = NULL; +int (*mlx_map_string_to_array_free_ptr)(mlx_map_string_to_array map) = NULL; +int (*mlx_map_string_to_array_insert_ptr)(mlx_map_string_to_array map, const char* key, const mlx_array value) = NULL; +int (*mlx_map_string_to_array_get_ptr)(mlx_array* value, const mlx_map_string_to_array map, const char* key) = NULL; +mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_ptr)(mlx_map_string_to_array map) = NULL; +int (*mlx_map_string_to_array_iterator_free_ptr)(mlx_map_string_to_array_iterator it) = NULL; +int (*mlx_map_string_to_array_iterator_next_ptr)(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) = NULL; +mlx_map_string_to_string (*mlx_map_string_to_string_new_ptr)(void) = NULL; +int (*mlx_map_string_to_string_set_ptr)(mlx_map_string_to_string* map, const mlx_map_string_to_string src) = NULL; +int (*mlx_map_string_to_string_free_ptr)(mlx_map_string_to_string map) = NULL; +int (*mlx_map_string_to_string_insert_ptr)(mlx_map_string_to_string map, const char* key, const char* value) = NULL; +int (*mlx_map_string_to_string_get_ptr)(const char** value, const mlx_map_string_to_string map, const char* key) = NULL; +mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_ptr)(mlx_map_string_to_string map) = NULL; +int (*mlx_map_string_to_string_iterator_free_ptr)(mlx_map_string_to_string_iterator it) = NULL; +int (*mlx_map_string_to_string_iterator_next_ptr)(const char** key, const char** value, mlx_map_string_to_string_iterator it) = NULL; +int (*mlx_clear_cache_ptr)(void) = NULL; +int (*mlx_get_active_memory_ptr)(size_t* res) = NULL; +int (*mlx_get_cache_memory_ptr)(size_t* res) = NULL; +int (*mlx_get_memory_limit_ptr)(size_t* res) = NULL; +int (*mlx_get_peak_memory_ptr)(size_t* res) = NULL; +int (*mlx_reset_peak_memory_ptr)(void) = NULL; +int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL; +int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL; +int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL; +mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL; +int (*mlx_metal_is_available_ptr)(bool* res) = NULL; +int (*mlx_metal_start_capture_ptr)(const char* path) = NULL; +int (*mlx_metal_stop_capture_ptr)(void) = NULL; +int (*mlx_abs_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_add_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_addmm_ptr)(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) = NULL; +int (*mlx_all_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_all_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_all_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_allclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) = NULL; +int (*mlx_any_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_any_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_any_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_arange_ptr)(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_arccos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arccosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arcsin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arcsinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arctan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_arctan2_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_arctanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_argmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_argmax_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_argmin_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_argmin_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_argpartition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) = NULL; +int (*mlx_argpartition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) = NULL; +int (*mlx_argsort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL; +int (*mlx_argsort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_array_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) = NULL; +int (*mlx_as_strided_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) = NULL; +int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL; +int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL; +int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL; +int (*mlx_ceil_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_clip_ptr)(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s) = NULL; +int (*mlx_concatenate_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) = NULL; +int (*mlx_concatenate_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) = NULL; +int (*mlx_conjugate_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_contiguous_ptr)(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) = NULL; +int (*mlx_conv1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) = NULL; +int (*mlx_conv2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) = NULL; +int (*mlx_conv3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) = NULL; +int (*mlx_conv_general_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) = NULL; +int (*mlx_conv_transpose1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) = NULL; +int (*mlx_conv_transpose2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) = NULL; +int (*mlx_conv_transpose3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) = NULL; +int (*mlx_copy_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_cos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_cosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_cummax_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; +int (*mlx_cummin_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; +int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; +int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; +int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL; +int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; +int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL; +int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_divmod_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_einsum_ptr)(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) = NULL; +int (*mlx_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_erf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_erfinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_exp_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_expand_dims_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_expand_dims_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL; +int (*mlx_expm1_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_eye_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_flatten_ptr)(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s) = NULL; +int (*mlx_floor_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_floor_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_from_fp8_ptr)(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_full_ptr)(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_full_like_ptr)(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_gather_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL; +int (*mlx_gather_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL; +int (*mlx_gather_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s) = NULL; +int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) = NULL; +int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL; +int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_isclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) = NULL; +int (*mlx_isfinite_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_isinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_isnan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_isneginf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_isposinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_kron_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_left_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_less_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_less_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_linspace_ptr)(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_log_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_log10_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_log1p_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_log2_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_logaddexp_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_logcumsumexp_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL; +int (*mlx_logical_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_logical_not_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_logical_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_logsumexp_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_logsumexp_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_logsumexp_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_masked_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) = NULL; +int (*mlx_matmul_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_max_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_max_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_max_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_maximum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_mean_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_mean_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_mean_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_median_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_meshgrid_ptr)(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s) = NULL; +int (*mlx_min_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_min_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_min_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_minimum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_moveaxis_ptr)(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) = NULL; +int (*mlx_multiply_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_nan_to_num_ptr)(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) = NULL; +int (*mlx_negative_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_not_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_number_of_elements_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_ones_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_ones_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_outer_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_pad_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s) = NULL; +int (*mlx_pad_symmetric_ptr)(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s) = NULL; +int (*mlx_partition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) = NULL; +int (*mlx_partition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) = NULL; +int (*mlx_power_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL; +int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; +int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; +int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL; +int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_reciprocal_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_remainder_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_repeat_axis_ptr)(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s) = NULL; +int (*mlx_repeat_ptr)(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s) = NULL; +int (*mlx_reshape_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL; +int (*mlx_right_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_roll_axis_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s) = NULL; +int (*mlx_roll_axes_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_roll_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s) = NULL; +int (*mlx_round_ptr)(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s) = NULL; +int (*mlx_rsqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; +int (*mlx_scatter_add_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_add_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; +int (*mlx_scatter_add_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL; +int (*mlx_scatter_max_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_max_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; +int (*mlx_scatter_min_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_min_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; +int (*mlx_scatter_prod_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_scatter_prod_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL; +int (*mlx_segmented_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s) = NULL; +int (*mlx_sigmoid_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_sign_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_sin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_sinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_slice_ptr)(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; +int (*mlx_slice_dynamic_ptr)(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) = NULL; +int (*mlx_slice_update_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL; +int (*mlx_slice_update_dynamic_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_softmax_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) = NULL; +int (*mlx_softmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) = NULL; +int (*mlx_softmax_ptr)(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s) = NULL; +int (*mlx_sort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL; +int (*mlx_sort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_split_ptr)(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s) = NULL; +int (*mlx_split_sections_ptr)(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s) = NULL; +int (*mlx_sqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_square_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_squeeze_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_squeeze_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL; +int (*mlx_squeeze_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_stack_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) = NULL; +int (*mlx_stack_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) = NULL; +int (*mlx_std_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) = NULL; +int (*mlx_std_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) = NULL; +int (*mlx_std_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) = NULL; +int (*mlx_stop_gradient_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_subtract_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL; +int (*mlx_sum_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_sum_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_sum_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL; +int (*mlx_swapaxes_ptr)(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s) = NULL; +int (*mlx_take_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) = NULL; +int (*mlx_take_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s) = NULL; +int (*mlx_take_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) = NULL; +int (*mlx_tan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_tanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_tensordot_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s) = NULL; +int (*mlx_tensordot_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL; +int (*mlx_tile_ptr)(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s) = NULL; +int (*mlx_to_fp8_ptr)(mlx_array* res, const mlx_array x, const mlx_stream s) = NULL; +int (*mlx_topk_axis_ptr)(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s) = NULL; +int (*mlx_topk_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL; +int (*mlx_trace_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_transpose_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL; +int (*mlx_transpose_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_tri_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s) = NULL; +int (*mlx_tril_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL; +int (*mlx_triu_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL; +int (*mlx_unflatten_ptr)(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s) = NULL; +int (*mlx_var_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) = NULL; +int (*mlx_var_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) = NULL; +int (*mlx_var_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) = NULL; +int (*mlx_view_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_where_ptr)(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s) = NULL; +int (*mlx_zeros_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) = NULL; +int (*mlx_zeros_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL; +int (*mlx_random_bernoulli_ptr)(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_bits_ptr)(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_categorical_shape_ptr)(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_categorical_num_samples_ptr)(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_categorical_ptr)(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_gumbel_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_key_ptr)(mlx_array* res, uint64_t seed) = NULL; +int (*mlx_random_laplace_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_multivariate_normal_ptr)(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_normal_broadcast_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_normal_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_permutation_ptr)(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_permutation_arange_ptr)(mlx_array* res, int x, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_randint_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_seed_ptr)(uint64_t seed) = NULL; +int (*mlx_random_split_num_ptr)(mlx_array* res, const mlx_array key, int num, const mlx_stream s) = NULL; +int (*mlx_random_split_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s) = NULL; +int (*mlx_random_truncated_normal_ptr)(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; +int (*mlx_random_uniform_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL; +mlx_stream (*mlx_stream_new_ptr)(void) = NULL; +mlx_stream (*mlx_stream_new_device_ptr)(mlx_device dev) = NULL; +int (*mlx_stream_set_ptr)(mlx_stream* stream, const mlx_stream src) = NULL; +int (*mlx_stream_free_ptr)(mlx_stream stream) = NULL; +int (*mlx_stream_tostring_ptr)(mlx_string* str, mlx_stream stream) = NULL; +bool (*mlx_stream_equal_ptr)(mlx_stream lhs, mlx_stream rhs) = NULL; +int (*mlx_stream_get_device_ptr)(mlx_device* dev, mlx_stream stream) = NULL; +int (*mlx_stream_get_index_ptr)(int* index, mlx_stream stream) = NULL; +int (*mlx_synchronize_ptr)(mlx_stream stream) = NULL; +int (*mlx_get_default_stream_ptr)(mlx_stream* stream, mlx_device dev) = NULL; +int (*mlx_set_default_stream_ptr)(mlx_stream stream) = NULL; +mlx_stream (*mlx_default_cpu_stream_new_ptr)(void) = NULL; +mlx_stream (*mlx_default_gpu_stream_new_ptr)(void) = NULL; +mlx_string (*mlx_string_new_ptr)(void) = NULL; +mlx_string (*mlx_string_new_data_ptr)(const char* str) = NULL; +int (*mlx_string_set_ptr)(mlx_string* str, const mlx_string src) = NULL; +const char* (*mlx_string_data_ptr)(mlx_string str) = NULL; +int (*mlx_string_free_ptr)(mlx_string str) = NULL; +int (*mlx_async_eval_ptr)(const mlx_vector_array outputs) = NULL; +int (*mlx_checkpoint_ptr)(mlx_closure* res, const mlx_closure fun) = NULL; +int (*mlx_custom_function_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap) = NULL; +int (*mlx_custom_vjp_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp) = NULL; +int (*mlx_eval_ptr)(const mlx_vector_array outputs) = NULL; +int (*mlx_jvp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents) = NULL; +int (*mlx_value_and_grad_ptr)(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num) = NULL; +int (*mlx_vjp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) = NULL; +int (*mlx_detail_vmap_replace_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num) = NULL; +int (*mlx_detail_vmap_trace_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num) = NULL; +mlx_vector_array (*mlx_vector_array_new_ptr)(void) = NULL; +int (*mlx_vector_array_set_ptr)(mlx_vector_array* vec, const mlx_vector_array src) = NULL; +int (*mlx_vector_array_free_ptr)(mlx_vector_array vec) = NULL; +mlx_vector_array (*mlx_vector_array_new_data_ptr)(const mlx_array* data, size_t size) = NULL; +mlx_vector_array (*mlx_vector_array_new_value_ptr)(const mlx_array val) = NULL; +int (*mlx_vector_array_set_data_ptr)(mlx_vector_array* vec, const mlx_array* data, size_t size) = NULL; +int (*mlx_vector_array_set_value_ptr)(mlx_vector_array* vec, const mlx_array val) = NULL; +int (*mlx_vector_array_append_data_ptr)(mlx_vector_array vec, const mlx_array* data, size_t size) = NULL; +int (*mlx_vector_array_append_value_ptr)(mlx_vector_array vec, const mlx_array val) = NULL; +size_t (*mlx_vector_array_size_ptr)(mlx_vector_array vec) = NULL; +int (*mlx_vector_array_get_ptr)(mlx_array* res, const mlx_vector_array vec, size_t idx) = NULL; +mlx_vector_vector_array (*mlx_vector_vector_array_new_ptr)(void) = NULL; +int (*mlx_vector_vector_array_set_ptr)(mlx_vector_vector_array* vec, const mlx_vector_vector_array src) = NULL; +int (*mlx_vector_vector_array_free_ptr)(mlx_vector_vector_array vec) = NULL; +mlx_vector_vector_array (*mlx_vector_vector_array_new_data_ptr)(const mlx_vector_array* data, size_t size) = NULL; +mlx_vector_vector_array (*mlx_vector_vector_array_new_value_ptr)(const mlx_vector_array val) = NULL; +int (*mlx_vector_vector_array_set_data_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) = NULL; +int (*mlx_vector_vector_array_set_value_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array val) = NULL; +int (*mlx_vector_vector_array_append_data_ptr)(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) = NULL; +int (*mlx_vector_vector_array_append_value_ptr)(mlx_vector_vector_array vec, const mlx_vector_array val) = NULL; +size_t (*mlx_vector_vector_array_size_ptr)(mlx_vector_vector_array vec) = NULL; +int (*mlx_vector_vector_array_get_ptr)(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) = NULL; +mlx_vector_int (*mlx_vector_int_new_ptr)(void) = NULL; +int (*mlx_vector_int_set_ptr)(mlx_vector_int* vec, const mlx_vector_int src) = NULL; +int (*mlx_vector_int_free_ptr)(mlx_vector_int vec) = NULL; +mlx_vector_int (*mlx_vector_int_new_data_ptr)(int* data, size_t size) = NULL; +mlx_vector_int (*mlx_vector_int_new_value_ptr)(int val) = NULL; +int (*mlx_vector_int_set_data_ptr)(mlx_vector_int* vec, int* data, size_t size) = NULL; +int (*mlx_vector_int_set_value_ptr)(mlx_vector_int* vec, int val) = NULL; +int (*mlx_vector_int_append_data_ptr)(mlx_vector_int vec, int* data, size_t size) = NULL; +int (*mlx_vector_int_append_value_ptr)(mlx_vector_int vec, int val) = NULL; +size_t (*mlx_vector_int_size_ptr)(mlx_vector_int vec) = NULL; +int (*mlx_vector_int_get_ptr)(int* res, const mlx_vector_int vec, size_t idx) = NULL; +mlx_vector_string (*mlx_vector_string_new_ptr)(void) = NULL; +int (*mlx_vector_string_set_ptr)(mlx_vector_string* vec, const mlx_vector_string src) = NULL; +int (*mlx_vector_string_free_ptr)(mlx_vector_string vec) = NULL; +mlx_vector_string (*mlx_vector_string_new_data_ptr)(const char** data, size_t size) = NULL; +mlx_vector_string (*mlx_vector_string_new_value_ptr)(const char* val) = NULL; +int (*mlx_vector_string_set_data_ptr)(mlx_vector_string* vec, const char** data, size_t size) = NULL; +int (*mlx_vector_string_set_value_ptr)(mlx_vector_string* vec, const char* val) = NULL; +int (*mlx_vector_string_append_data_ptr)(mlx_vector_string vec, const char** data, size_t size) = NULL; +int (*mlx_vector_string_append_value_ptr)(mlx_vector_string vec, const char* val) = NULL; +size_t (*mlx_vector_string_size_ptr)(mlx_vector_string vec) = NULL; +int (*mlx_vector_string_get_ptr)(char** res, const mlx_vector_string vec, size_t idx) = NULL; +int (*mlx_version_ptr)(mlx_string* str_) = NULL; + +// Initialize all function pointers via dlsym +int mlx_load_functions(void* handle) { + if (handle == NULL) { + fprintf(stderr, "MLX: Invalid library handle\n"); + return -1; + } + + mlx_dtype_size_ptr = dlsym(handle, "mlx_dtype_size"); + if (mlx_dtype_size_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_dtype_size\n"); + return -1; + } + mlx_array_tostring_ptr = dlsym(handle, "mlx_array_tostring"); + if (mlx_array_tostring_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_tostring\n"); + return -1; + } + mlx_array_new_ptr = dlsym(handle, "mlx_array_new"); + if (mlx_array_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new\n"); + return -1; + } + mlx_array_free_ptr = dlsym(handle, "mlx_array_free"); + if (mlx_array_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_free\n"); + return -1; + } + mlx_array_new_bool_ptr = dlsym(handle, "mlx_array_new_bool"); + if (mlx_array_new_bool_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_bool\n"); + return -1; + } + mlx_array_new_int_ptr = dlsym(handle, "mlx_array_new_int"); + if (mlx_array_new_int_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_int\n"); + return -1; + } + mlx_array_new_float32_ptr = dlsym(handle, "mlx_array_new_float32"); + if (mlx_array_new_float32_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float32\n"); + return -1; + } + mlx_array_new_float_ptr = dlsym(handle, "mlx_array_new_float"); + if (mlx_array_new_float_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float\n"); + return -1; + } + mlx_array_new_float64_ptr = dlsym(handle, "mlx_array_new_float64"); + if (mlx_array_new_float64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float64\n"); + return -1; + } + mlx_array_new_double_ptr = dlsym(handle, "mlx_array_new_double"); + if (mlx_array_new_double_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_double\n"); + return -1; + } + mlx_array_new_complex_ptr = dlsym(handle, "mlx_array_new_complex"); + if (mlx_array_new_complex_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_complex\n"); + return -1; + } + mlx_array_new_data_ptr = dlsym(handle, "mlx_array_new_data"); + if (mlx_array_new_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n"); + return -1; + } + mlx_array_set_ptr = dlsym(handle, "mlx_array_set"); + if (mlx_array_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n"); + return -1; + } + mlx_array_set_bool_ptr = dlsym(handle, "mlx_array_set_bool"); + if (mlx_array_set_bool_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_bool\n"); + return -1; + } + mlx_array_set_int_ptr = dlsym(handle, "mlx_array_set_int"); + if (mlx_array_set_int_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_int\n"); + return -1; + } + mlx_array_set_float32_ptr = dlsym(handle, "mlx_array_set_float32"); + if (mlx_array_set_float32_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float32\n"); + return -1; + } + mlx_array_set_float_ptr = dlsym(handle, "mlx_array_set_float"); + if (mlx_array_set_float_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float\n"); + return -1; + } + mlx_array_set_float64_ptr = dlsym(handle, "mlx_array_set_float64"); + if (mlx_array_set_float64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float64\n"); + return -1; + } + mlx_array_set_double_ptr = dlsym(handle, "mlx_array_set_double"); + if (mlx_array_set_double_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_double\n"); + return -1; + } + mlx_array_set_complex_ptr = dlsym(handle, "mlx_array_set_complex"); + if (mlx_array_set_complex_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_complex\n"); + return -1; + } + mlx_array_set_data_ptr = dlsym(handle, "mlx_array_set_data"); + if (mlx_array_set_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_data\n"); + return -1; + } + mlx_array_itemsize_ptr = dlsym(handle, "mlx_array_itemsize"); + if (mlx_array_itemsize_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_itemsize\n"); + return -1; + } + mlx_array_size_ptr = dlsym(handle, "mlx_array_size"); + if (mlx_array_size_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_size\n"); + return -1; + } + mlx_array_nbytes_ptr = dlsym(handle, "mlx_array_nbytes"); + if (mlx_array_nbytes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_nbytes\n"); + return -1; + } + mlx_array_ndim_ptr = dlsym(handle, "mlx_array_ndim"); + if (mlx_array_ndim_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_ndim\n"); + return -1; + } + mlx_array_shape_ptr = dlsym(handle, "mlx_array_shape"); + if (mlx_array_shape_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_shape\n"); + return -1; + } + mlx_array_strides_ptr = dlsym(handle, "mlx_array_strides"); + if (mlx_array_strides_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_strides\n"); + return -1; + } + mlx_array_dim_ptr = dlsym(handle, "mlx_array_dim"); + if (mlx_array_dim_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dim\n"); + return -1; + } + mlx_array_dtype_ptr = dlsym(handle, "mlx_array_dtype"); + if (mlx_array_dtype_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dtype\n"); + return -1; + } + mlx_array_eval_ptr = dlsym(handle, "mlx_array_eval"); + if (mlx_array_eval_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_eval\n"); + return -1; + } + mlx_array_item_bool_ptr = dlsym(handle, "mlx_array_item_bool"); + if (mlx_array_item_bool_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bool\n"); + return -1; + } + mlx_array_item_uint8_ptr = dlsym(handle, "mlx_array_item_uint8"); + if (mlx_array_item_uint8_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint8\n"); + return -1; + } + mlx_array_item_uint16_ptr = dlsym(handle, "mlx_array_item_uint16"); + if (mlx_array_item_uint16_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint16\n"); + return -1; + } + mlx_array_item_uint32_ptr = dlsym(handle, "mlx_array_item_uint32"); + if (mlx_array_item_uint32_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint32\n"); + return -1; + } + mlx_array_item_uint64_ptr = dlsym(handle, "mlx_array_item_uint64"); + if (mlx_array_item_uint64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint64\n"); + return -1; + } + mlx_array_item_int8_ptr = dlsym(handle, "mlx_array_item_int8"); + if (mlx_array_item_int8_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int8\n"); + return -1; + } + mlx_array_item_int16_ptr = dlsym(handle, "mlx_array_item_int16"); + if (mlx_array_item_int16_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int16\n"); + return -1; + } + mlx_array_item_int32_ptr = dlsym(handle, "mlx_array_item_int32"); + if (mlx_array_item_int32_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int32\n"); + return -1; + } + mlx_array_item_int64_ptr = dlsym(handle, "mlx_array_item_int64"); + if (mlx_array_item_int64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int64\n"); + return -1; + } + mlx_array_item_float32_ptr = dlsym(handle, "mlx_array_item_float32"); + if (mlx_array_item_float32_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float32\n"); + return -1; + } + mlx_array_item_float64_ptr = dlsym(handle, "mlx_array_item_float64"); + if (mlx_array_item_float64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float64\n"); + return -1; + } + mlx_array_item_complex64_ptr = dlsym(handle, "mlx_array_item_complex64"); + if (mlx_array_item_complex64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_complex64\n"); + return -1; + } +#if defined(__aarch64__) || defined(_M_ARM64) + mlx_array_item_float16_ptr = dlsym(handle, "mlx_array_item_float16"); + if (mlx_array_item_float16_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float16\n"); + return -1; + } +#endif +#if defined(__aarch64__) || defined(_M_ARM64) + mlx_array_item_bfloat16_ptr = dlsym(handle, "mlx_array_item_bfloat16"); + if (mlx_array_item_bfloat16_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bfloat16\n"); + return -1; + } +#endif + mlx_array_data_bool_ptr = dlsym(handle, "mlx_array_data_bool"); + if (mlx_array_data_bool_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bool\n"); + return -1; + } + mlx_array_data_uint8_ptr = dlsym(handle, "mlx_array_data_uint8"); + if (mlx_array_data_uint8_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint8\n"); + return -1; + } + mlx_array_data_uint16_ptr = dlsym(handle, "mlx_array_data_uint16"); + if (mlx_array_data_uint16_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint16\n"); + return -1; + } + mlx_array_data_uint32_ptr = dlsym(handle, "mlx_array_data_uint32"); + if (mlx_array_data_uint32_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint32\n"); + return -1; + } + mlx_array_data_uint64_ptr = dlsym(handle, "mlx_array_data_uint64"); + if (mlx_array_data_uint64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint64\n"); + return -1; + } + mlx_array_data_int8_ptr = dlsym(handle, "mlx_array_data_int8"); + if (mlx_array_data_int8_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int8\n"); + return -1; + } + mlx_array_data_int16_ptr = dlsym(handle, "mlx_array_data_int16"); + if (mlx_array_data_int16_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int16\n"); + return -1; + } + mlx_array_data_int32_ptr = dlsym(handle, "mlx_array_data_int32"); + if (mlx_array_data_int32_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int32\n"); + return -1; + } + mlx_array_data_int64_ptr = dlsym(handle, "mlx_array_data_int64"); + if (mlx_array_data_int64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int64\n"); + return -1; + } + mlx_array_data_float32_ptr = dlsym(handle, "mlx_array_data_float32"); + if (mlx_array_data_float32_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float32\n"); + return -1; + } + mlx_array_data_float64_ptr = dlsym(handle, "mlx_array_data_float64"); + if (mlx_array_data_float64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float64\n"); + return -1; + } + mlx_array_data_complex64_ptr = dlsym(handle, "mlx_array_data_complex64"); + if (mlx_array_data_complex64_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_complex64\n"); + return -1; + } +#if defined(__aarch64__) || defined(_M_ARM64) + mlx_array_data_float16_ptr = dlsym(handle, "mlx_array_data_float16"); + if (mlx_array_data_float16_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float16\n"); + return -1; + } +#endif +#if defined(__aarch64__) || defined(_M_ARM64) + mlx_array_data_bfloat16_ptr = dlsym(handle, "mlx_array_data_bfloat16"); + if (mlx_array_data_bfloat16_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bfloat16\n"); + return -1; + } +#endif + _mlx_array_is_available_ptr = dlsym(handle, "_mlx_array_is_available"); + if (_mlx_array_is_available_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_available\n"); + return -1; + } + _mlx_array_wait_ptr = dlsym(handle, "_mlx_array_wait"); + if (_mlx_array_wait_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_wait\n"); + return -1; + } + _mlx_array_is_contiguous_ptr = dlsym(handle, "_mlx_array_is_contiguous"); + if (_mlx_array_is_contiguous_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_contiguous\n"); + return -1; + } + _mlx_array_is_row_contiguous_ptr = dlsym(handle, "_mlx_array_is_row_contiguous"); + if (_mlx_array_is_row_contiguous_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_row_contiguous\n"); + return -1; + } + _mlx_array_is_col_contiguous_ptr = dlsym(handle, "_mlx_array_is_col_contiguous"); + if (_mlx_array_is_col_contiguous_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_col_contiguous\n"); + return -1; + } + mlx_closure_new_ptr = dlsym(handle, "mlx_closure_new"); + if (mlx_closure_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new\n"); + return -1; + } + mlx_closure_free_ptr = dlsym(handle, "mlx_closure_free"); + if (mlx_closure_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_free\n"); + return -1; + } + mlx_closure_new_func_ptr = dlsym(handle, "mlx_closure_new_func"); + if (mlx_closure_new_func_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func\n"); + return -1; + } + mlx_closure_new_func_payload_ptr = dlsym(handle, "mlx_closure_new_func_payload"); + if (mlx_closure_new_func_payload_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func_payload\n"); + return -1; + } + mlx_closure_set_ptr = dlsym(handle, "mlx_closure_set"); + if (mlx_closure_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_set\n"); + return -1; + } + mlx_closure_apply_ptr = dlsym(handle, "mlx_closure_apply"); + if (mlx_closure_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_apply\n"); + return -1; + } + mlx_closure_new_unary_ptr = dlsym(handle, "mlx_closure_new_unary"); + if (mlx_closure_new_unary_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_unary\n"); + return -1; + } + mlx_closure_kwargs_new_ptr = dlsym(handle, "mlx_closure_kwargs_new"); + if (mlx_closure_kwargs_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new\n"); + return -1; + } + mlx_closure_kwargs_free_ptr = dlsym(handle, "mlx_closure_kwargs_free"); + if (mlx_closure_kwargs_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_free\n"); + return -1; + } + mlx_closure_kwargs_new_func_ptr = dlsym(handle, "mlx_closure_kwargs_new_func"); + if (mlx_closure_kwargs_new_func_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func\n"); + return -1; + } + mlx_closure_kwargs_new_func_payload_ptr = dlsym(handle, "mlx_closure_kwargs_new_func_payload"); + if (mlx_closure_kwargs_new_func_payload_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func_payload\n"); + return -1; + } + mlx_closure_kwargs_set_ptr = dlsym(handle, "mlx_closure_kwargs_set"); + if (mlx_closure_kwargs_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_set\n"); + return -1; + } + mlx_closure_kwargs_apply_ptr = dlsym(handle, "mlx_closure_kwargs_apply"); + if (mlx_closure_kwargs_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_apply\n"); + return -1; + } + mlx_closure_value_and_grad_new_ptr = dlsym(handle, "mlx_closure_value_and_grad_new"); + if (mlx_closure_value_and_grad_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new\n"); + return -1; + } + mlx_closure_value_and_grad_free_ptr = dlsym(handle, "mlx_closure_value_and_grad_free"); + if (mlx_closure_value_and_grad_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_free\n"); + return -1; + } + mlx_closure_value_and_grad_new_func_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func"); + if (mlx_closure_value_and_grad_new_func_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func\n"); + return -1; + } + mlx_closure_value_and_grad_new_func_payload_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func_payload"); + if (mlx_closure_value_and_grad_new_func_payload_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func_payload\n"); + return -1; + } + mlx_closure_value_and_grad_set_ptr = dlsym(handle, "mlx_closure_value_and_grad_set"); + if (mlx_closure_value_and_grad_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_set\n"); + return -1; + } + mlx_closure_value_and_grad_apply_ptr = dlsym(handle, "mlx_closure_value_and_grad_apply"); + if (mlx_closure_value_and_grad_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_apply\n"); + return -1; + } + mlx_closure_custom_new_ptr = dlsym(handle, "mlx_closure_custom_new"); + if (mlx_closure_custom_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new\n"); + return -1; + } + mlx_closure_custom_free_ptr = dlsym(handle, "mlx_closure_custom_free"); + if (mlx_closure_custom_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_free\n"); + return -1; + } + mlx_closure_custom_new_func_ptr = dlsym(handle, "mlx_closure_custom_new_func"); + if (mlx_closure_custom_new_func_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func\n"); + return -1; + } + mlx_closure_custom_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_new_func_payload"); + if (mlx_closure_custom_new_func_payload_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func_payload\n"); + return -1; + } + mlx_closure_custom_set_ptr = dlsym(handle, "mlx_closure_custom_set"); + if (mlx_closure_custom_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_set\n"); + return -1; + } + mlx_closure_custom_apply_ptr = dlsym(handle, "mlx_closure_custom_apply"); + if (mlx_closure_custom_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_apply\n"); + return -1; + } + mlx_closure_custom_jvp_new_ptr = dlsym(handle, "mlx_closure_custom_jvp_new"); + if (mlx_closure_custom_jvp_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new\n"); + return -1; + } + mlx_closure_custom_jvp_free_ptr = dlsym(handle, "mlx_closure_custom_jvp_free"); + if (mlx_closure_custom_jvp_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_free\n"); + return -1; + } + mlx_closure_custom_jvp_new_func_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func"); + if (mlx_closure_custom_jvp_new_func_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func\n"); + return -1; + } + mlx_closure_custom_jvp_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func_payload"); + if (mlx_closure_custom_jvp_new_func_payload_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func_payload\n"); + return -1; + } + mlx_closure_custom_jvp_set_ptr = dlsym(handle, "mlx_closure_custom_jvp_set"); + if (mlx_closure_custom_jvp_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_set\n"); + return -1; + } + mlx_closure_custom_jvp_apply_ptr = dlsym(handle, "mlx_closure_custom_jvp_apply"); + if (mlx_closure_custom_jvp_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_apply\n"); + return -1; + } + mlx_closure_custom_vmap_new_ptr = dlsym(handle, "mlx_closure_custom_vmap_new"); + if (mlx_closure_custom_vmap_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new\n"); + return -1; + } + mlx_closure_custom_vmap_free_ptr = dlsym(handle, "mlx_closure_custom_vmap_free"); + if (mlx_closure_custom_vmap_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_free\n"); + return -1; + } + mlx_closure_custom_vmap_new_func_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func"); + if (mlx_closure_custom_vmap_new_func_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func\n"); + return -1; + } + mlx_closure_custom_vmap_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func_payload"); + if (mlx_closure_custom_vmap_new_func_payload_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func_payload\n"); + return -1; + } + mlx_closure_custom_vmap_set_ptr = dlsym(handle, "mlx_closure_custom_vmap_set"); + if (mlx_closure_custom_vmap_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_set\n"); + return -1; + } + mlx_closure_custom_vmap_apply_ptr = dlsym(handle, "mlx_closure_custom_vmap_apply"); + if (mlx_closure_custom_vmap_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_apply\n"); + return -1; + } + mlx_compile_ptr = dlsym(handle, "mlx_compile"); + if (mlx_compile_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_compile\n"); + return -1; + } + mlx_detail_compile_ptr = dlsym(handle, "mlx_detail_compile"); + if (mlx_detail_compile_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile\n"); + return -1; + } + mlx_detail_compile_clear_cache_ptr = dlsym(handle, "mlx_detail_compile_clear_cache"); + if (mlx_detail_compile_clear_cache_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_clear_cache\n"); + return -1; + } + mlx_detail_compile_erase_ptr = dlsym(handle, "mlx_detail_compile_erase"); + if (mlx_detail_compile_erase_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_erase\n"); + return -1; + } + mlx_disable_compile_ptr = dlsym(handle, "mlx_disable_compile"); + if (mlx_disable_compile_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_disable_compile\n"); + return -1; + } + mlx_enable_compile_ptr = dlsym(handle, "mlx_enable_compile"); + if (mlx_enable_compile_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_enable_compile\n"); + return -1; + } + mlx_set_compile_mode_ptr = dlsym(handle, "mlx_set_compile_mode"); + if (mlx_set_compile_mode_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n"); + return -1; + } + mlx_device_new_ptr = dlsym(handle, "mlx_device_new"); + if (mlx_device_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n"); + return -1; + } + mlx_device_new_type_ptr = dlsym(handle, "mlx_device_new_type"); + if (mlx_device_new_type_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new_type\n"); + return -1; + } + mlx_device_free_ptr = dlsym(handle, "mlx_device_free"); + if (mlx_device_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_free\n"); + return -1; + } + mlx_device_set_ptr = dlsym(handle, "mlx_device_set"); + if (mlx_device_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_set\n"); + return -1; + } + mlx_device_tostring_ptr = dlsym(handle, "mlx_device_tostring"); + if (mlx_device_tostring_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_tostring\n"); + return -1; + } + mlx_device_equal_ptr = dlsym(handle, "mlx_device_equal"); + if (mlx_device_equal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_equal\n"); + return -1; + } + mlx_device_get_index_ptr = dlsym(handle, "mlx_device_get_index"); + if (mlx_device_get_index_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_index\n"); + return -1; + } + mlx_device_get_type_ptr = dlsym(handle, "mlx_device_get_type"); + if (mlx_device_get_type_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_type\n"); + return -1; + } + mlx_get_default_device_ptr = dlsym(handle, "mlx_get_default_device"); + if (mlx_get_default_device_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_device\n"); + return -1; + } + mlx_set_default_device_ptr = dlsym(handle, "mlx_set_default_device"); + if (mlx_set_default_device_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n"); + return -1; + } + mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather"); + if (mlx_distributed_all_gather_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n"); + return -1; + } + mlx_distributed_all_max_ptr = dlsym(handle, "mlx_distributed_all_max"); + if (mlx_distributed_all_max_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_max\n"); + return -1; + } + mlx_distributed_all_min_ptr = dlsym(handle, "mlx_distributed_all_min"); + if (mlx_distributed_all_min_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_min\n"); + return -1; + } + mlx_distributed_all_sum_ptr = dlsym(handle, "mlx_distributed_all_sum"); + if (mlx_distributed_all_sum_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_sum\n"); + return -1; + } + mlx_distributed_recv_ptr = dlsym(handle, "mlx_distributed_recv"); + if (mlx_distributed_recv_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv\n"); + return -1; + } + mlx_distributed_recv_like_ptr = dlsym(handle, "mlx_distributed_recv_like"); + if (mlx_distributed_recv_like_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv_like\n"); + return -1; + } + mlx_distributed_send_ptr = dlsym(handle, "mlx_distributed_send"); + if (mlx_distributed_send_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_send\n"); + return -1; + } + mlx_distributed_sum_scatter_ptr = dlsym(handle, "mlx_distributed_sum_scatter"); + if (mlx_distributed_sum_scatter_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_sum_scatter\n"); + return -1; + } + mlx_distributed_group_rank_ptr = dlsym(handle, "mlx_distributed_group_rank"); + if (mlx_distributed_group_rank_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_rank\n"); + return -1; + } + mlx_distributed_group_size_ptr = dlsym(handle, "mlx_distributed_group_size"); + if (mlx_distributed_group_size_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_size\n"); + return -1; + } + mlx_distributed_group_split_ptr = dlsym(handle, "mlx_distributed_group_split"); + if (mlx_distributed_group_split_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_split\n"); + return -1; + } + mlx_distributed_is_available_ptr = dlsym(handle, "mlx_distributed_is_available"); + if (mlx_distributed_is_available_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_is_available\n"); + return -1; + } + mlx_distributed_init_ptr = dlsym(handle, "mlx_distributed_init"); + if (mlx_distributed_init_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n"); + return -1; + } + mlx_set_error_handler_ptr = dlsym(handle, "mlx_set_error_handler"); + if (mlx_set_error_handler_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_set_error_handler\n"); + return -1; + } + _mlx_error_ptr = dlsym(handle, "_mlx_error"); + if (_mlx_error_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: _mlx_error\n"); + return -1; + } + mlx_export_function_ptr = dlsym(handle, "mlx_export_function"); + if (mlx_export_function_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function\n"); + return -1; + } + mlx_export_function_kwargs_ptr = dlsym(handle, "mlx_export_function_kwargs"); + if (mlx_export_function_kwargs_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function_kwargs\n"); + return -1; + } + mlx_function_exporter_new_ptr = dlsym(handle, "mlx_function_exporter_new"); + if (mlx_function_exporter_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_new\n"); + return -1; + } + mlx_function_exporter_free_ptr = dlsym(handle, "mlx_function_exporter_free"); + if (mlx_function_exporter_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_free\n"); + return -1; + } + mlx_function_exporter_apply_ptr = dlsym(handle, "mlx_function_exporter_apply"); + if (mlx_function_exporter_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply\n"); + return -1; + } + mlx_function_exporter_apply_kwargs_ptr = dlsym(handle, "mlx_function_exporter_apply_kwargs"); + if (mlx_function_exporter_apply_kwargs_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply_kwargs\n"); + return -1; + } + mlx_imported_function_new_ptr = dlsym(handle, "mlx_imported_function_new"); + if (mlx_imported_function_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_new\n"); + return -1; + } + mlx_imported_function_free_ptr = dlsym(handle, "mlx_imported_function_free"); + if (mlx_imported_function_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_free\n"); + return -1; + } + mlx_imported_function_apply_ptr = dlsym(handle, "mlx_imported_function_apply"); + if (mlx_imported_function_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply\n"); + return -1; + } + mlx_imported_function_apply_kwargs_ptr = dlsym(handle, "mlx_imported_function_apply_kwargs"); + if (mlx_imported_function_apply_kwargs_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply_kwargs\n"); + return -1; + } + mlx_fast_cuda_kernel_config_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_new"); + if (mlx_fast_cuda_kernel_config_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_new\n"); + return -1; + } + mlx_fast_cuda_kernel_config_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_free"); + if (mlx_fast_cuda_kernel_config_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_free\n"); + return -1; + } + mlx_fast_cuda_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_output_arg"); + if (mlx_fast_cuda_kernel_config_add_output_arg_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_output_arg\n"); + return -1; + } + mlx_fast_cuda_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_grid"); + if (mlx_fast_cuda_kernel_config_set_grid_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_grid\n"); + return -1; + } + mlx_fast_cuda_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_thread_group"); + if (mlx_fast_cuda_kernel_config_set_thread_group_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_thread_group\n"); + return -1; + } + mlx_fast_cuda_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_init_value"); + if (mlx_fast_cuda_kernel_config_set_init_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_init_value\n"); + return -1; + } + mlx_fast_cuda_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_verbose"); + if (mlx_fast_cuda_kernel_config_set_verbose_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_verbose\n"); + return -1; + } + mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype"); + if (mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_dtype\n"); + return -1; + } + mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int"); + if (mlx_fast_cuda_kernel_config_add_template_arg_int_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_int\n"); + return -1; + } + mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool"); + if (mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_bool\n"); + return -1; + } + mlx_fast_cuda_kernel_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_new"); + if (mlx_fast_cuda_kernel_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_new\n"); + return -1; + } + mlx_fast_cuda_kernel_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_free"); + if (mlx_fast_cuda_kernel_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_free\n"); + return -1; + } + mlx_fast_cuda_kernel_apply_ptr = dlsym(handle, "mlx_fast_cuda_kernel_apply"); + if (mlx_fast_cuda_kernel_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_apply\n"); + return -1; + } + mlx_fast_layer_norm_ptr = dlsym(handle, "mlx_fast_layer_norm"); + if (mlx_fast_layer_norm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_layer_norm\n"); + return -1; + } + mlx_fast_metal_kernel_config_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_new"); + if (mlx_fast_metal_kernel_config_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_new\n"); + return -1; + } + mlx_fast_metal_kernel_config_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_free"); + if (mlx_fast_metal_kernel_config_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_free\n"); + return -1; + } + mlx_fast_metal_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_output_arg"); + if (mlx_fast_metal_kernel_config_add_output_arg_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_output_arg\n"); + return -1; + } + mlx_fast_metal_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_grid"); + if (mlx_fast_metal_kernel_config_set_grid_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_grid\n"); + return -1; + } + mlx_fast_metal_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_thread_group"); + if (mlx_fast_metal_kernel_config_set_thread_group_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_thread_group\n"); + return -1; + } + mlx_fast_metal_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_init_value"); + if (mlx_fast_metal_kernel_config_set_init_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_init_value\n"); + return -1; + } + mlx_fast_metal_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_verbose"); + if (mlx_fast_metal_kernel_config_set_verbose_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_verbose\n"); + return -1; + } + mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype"); + if (mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_dtype\n"); + return -1; + } + mlx_fast_metal_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_int"); + if (mlx_fast_metal_kernel_config_add_template_arg_int_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_int\n"); + return -1; + } + mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool"); + if (mlx_fast_metal_kernel_config_add_template_arg_bool_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_bool\n"); + return -1; + } + mlx_fast_metal_kernel_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_new"); + if (mlx_fast_metal_kernel_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_new\n"); + return -1; + } + mlx_fast_metal_kernel_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_free"); + if (mlx_fast_metal_kernel_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_free\n"); + return -1; + } + mlx_fast_metal_kernel_apply_ptr = dlsym(handle, "mlx_fast_metal_kernel_apply"); + if (mlx_fast_metal_kernel_apply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_apply\n"); + return -1; + } + mlx_fast_rms_norm_ptr = dlsym(handle, "mlx_fast_rms_norm"); + if (mlx_fast_rms_norm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rms_norm\n"); + return -1; + } + mlx_fast_rope_ptr = dlsym(handle, "mlx_fast_rope"); + if (mlx_fast_rope_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope\n"); + return -1; + } + mlx_fast_rope_dynamic_ptr = dlsym(handle, "mlx_fast_rope_dynamic"); + if (mlx_fast_rope_dynamic_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope_dynamic\n"); + return -1; + } + mlx_fast_scaled_dot_product_attention_ptr = dlsym(handle, "mlx_fast_scaled_dot_product_attention"); + if (mlx_fast_scaled_dot_product_attention_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_scaled_dot_product_attention\n"); + return -1; + } + mlx_fft_fft_ptr = dlsym(handle, "mlx_fft_fft"); + if (mlx_fft_fft_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft\n"); + return -1; + } + mlx_fft_fft2_ptr = dlsym(handle, "mlx_fft_fft2"); + if (mlx_fft_fft2_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft2\n"); + return -1; + } + mlx_fft_fftn_ptr = dlsym(handle, "mlx_fft_fftn"); + if (mlx_fft_fftn_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftn\n"); + return -1; + } + mlx_fft_fftshift_ptr = dlsym(handle, "mlx_fft_fftshift"); + if (mlx_fft_fftshift_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftshift\n"); + return -1; + } + mlx_fft_ifft_ptr = dlsym(handle, "mlx_fft_ifft"); + if (mlx_fft_ifft_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft\n"); + return -1; + } + mlx_fft_ifft2_ptr = dlsym(handle, "mlx_fft_ifft2"); + if (mlx_fft_ifft2_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft2\n"); + return -1; + } + mlx_fft_ifftn_ptr = dlsym(handle, "mlx_fft_ifftn"); + if (mlx_fft_ifftn_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftn\n"); + return -1; + } + mlx_fft_ifftshift_ptr = dlsym(handle, "mlx_fft_ifftshift"); + if (mlx_fft_ifftshift_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftshift\n"); + return -1; + } + mlx_fft_irfft_ptr = dlsym(handle, "mlx_fft_irfft"); + if (mlx_fft_irfft_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft\n"); + return -1; + } + mlx_fft_irfft2_ptr = dlsym(handle, "mlx_fft_irfft2"); + if (mlx_fft_irfft2_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft2\n"); + return -1; + } + mlx_fft_irfftn_ptr = dlsym(handle, "mlx_fft_irfftn"); + if (mlx_fft_irfftn_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfftn\n"); + return -1; + } + mlx_fft_rfft_ptr = dlsym(handle, "mlx_fft_rfft"); + if (mlx_fft_rfft_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft\n"); + return -1; + } + mlx_fft_rfft2_ptr = dlsym(handle, "mlx_fft_rfft2"); + if (mlx_fft_rfft2_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft2\n"); + return -1; + } + mlx_fft_rfftn_ptr = dlsym(handle, "mlx_fft_rfftn"); + if (mlx_fft_rfftn_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftn\n"); + return -1; + } + mlx_load_reader_ptr = dlsym(handle, "mlx_load_reader"); + if (mlx_load_reader_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_load_reader\n"); + return -1; + } + mlx_load_ptr = dlsym(handle, "mlx_load"); + if (mlx_load_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_load\n"); + return -1; + } + mlx_load_safetensors_reader_ptr = dlsym(handle, "mlx_load_safetensors_reader"); + if (mlx_load_safetensors_reader_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors_reader\n"); + return -1; + } + mlx_load_safetensors_ptr = dlsym(handle, "mlx_load_safetensors"); + if (mlx_load_safetensors_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors\n"); + return -1; + } + mlx_save_writer_ptr = dlsym(handle, "mlx_save_writer"); + if (mlx_save_writer_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_save_writer\n"); + return -1; + } + mlx_save_ptr = dlsym(handle, "mlx_save"); + if (mlx_save_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_save\n"); + return -1; + } + mlx_save_safetensors_writer_ptr = dlsym(handle, "mlx_save_safetensors_writer"); + if (mlx_save_safetensors_writer_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors_writer\n"); + return -1; + } + mlx_save_safetensors_ptr = dlsym(handle, "mlx_save_safetensors"); + if (mlx_save_safetensors_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors\n"); + return -1; + } + mlx_io_reader_new_ptr = dlsym(handle, "mlx_io_reader_new"); + if (mlx_io_reader_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_new\n"); + return -1; + } + mlx_io_reader_descriptor_ptr = dlsym(handle, "mlx_io_reader_descriptor"); + if (mlx_io_reader_descriptor_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_descriptor\n"); + return -1; + } + mlx_io_reader_tostring_ptr = dlsym(handle, "mlx_io_reader_tostring"); + if (mlx_io_reader_tostring_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_tostring\n"); + return -1; + } + mlx_io_reader_free_ptr = dlsym(handle, "mlx_io_reader_free"); + if (mlx_io_reader_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_free\n"); + return -1; + } + mlx_io_writer_new_ptr = dlsym(handle, "mlx_io_writer_new"); + if (mlx_io_writer_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_new\n"); + return -1; + } + mlx_io_writer_descriptor_ptr = dlsym(handle, "mlx_io_writer_descriptor"); + if (mlx_io_writer_descriptor_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_descriptor\n"); + return -1; + } + mlx_io_writer_tostring_ptr = dlsym(handle, "mlx_io_writer_tostring"); + if (mlx_io_writer_tostring_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_tostring\n"); + return -1; + } + mlx_io_writer_free_ptr = dlsym(handle, "mlx_io_writer_free"); + if (mlx_io_writer_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_free\n"); + return -1; + } + mlx_linalg_cholesky_ptr = dlsym(handle, "mlx_linalg_cholesky"); + if (mlx_linalg_cholesky_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky\n"); + return -1; + } + mlx_linalg_cholesky_inv_ptr = dlsym(handle, "mlx_linalg_cholesky_inv"); + if (mlx_linalg_cholesky_inv_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky_inv\n"); + return -1; + } + mlx_linalg_cross_ptr = dlsym(handle, "mlx_linalg_cross"); + if (mlx_linalg_cross_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cross\n"); + return -1; + } + mlx_linalg_eig_ptr = dlsym(handle, "mlx_linalg_eig"); + if (mlx_linalg_eig_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eig\n"); + return -1; + } + mlx_linalg_eigh_ptr = dlsym(handle, "mlx_linalg_eigh"); + if (mlx_linalg_eigh_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigh\n"); + return -1; + } + mlx_linalg_eigvals_ptr = dlsym(handle, "mlx_linalg_eigvals"); + if (mlx_linalg_eigvals_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvals\n"); + return -1; + } + mlx_linalg_eigvalsh_ptr = dlsym(handle, "mlx_linalg_eigvalsh"); + if (mlx_linalg_eigvalsh_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvalsh\n"); + return -1; + } + mlx_linalg_inv_ptr = dlsym(handle, "mlx_linalg_inv"); + if (mlx_linalg_inv_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_inv\n"); + return -1; + } + mlx_linalg_lu_ptr = dlsym(handle, "mlx_linalg_lu"); + if (mlx_linalg_lu_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu\n"); + return -1; + } + mlx_linalg_lu_factor_ptr = dlsym(handle, "mlx_linalg_lu_factor"); + if (mlx_linalg_lu_factor_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu_factor\n"); + return -1; + } + mlx_linalg_norm_ptr = dlsym(handle, "mlx_linalg_norm"); + if (mlx_linalg_norm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm\n"); + return -1; + } + mlx_linalg_norm_matrix_ptr = dlsym(handle, "mlx_linalg_norm_matrix"); + if (mlx_linalg_norm_matrix_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_matrix\n"); + return -1; + } + mlx_linalg_norm_l2_ptr = dlsym(handle, "mlx_linalg_norm_l2"); + if (mlx_linalg_norm_l2_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_l2\n"); + return -1; + } + mlx_linalg_pinv_ptr = dlsym(handle, "mlx_linalg_pinv"); + if (mlx_linalg_pinv_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_pinv\n"); + return -1; + } + mlx_linalg_qr_ptr = dlsym(handle, "mlx_linalg_qr"); + if (mlx_linalg_qr_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_qr\n"); + return -1; + } + mlx_linalg_solve_ptr = dlsym(handle, "mlx_linalg_solve"); + if (mlx_linalg_solve_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve\n"); + return -1; + } + mlx_linalg_solve_triangular_ptr = dlsym(handle, "mlx_linalg_solve_triangular"); + if (mlx_linalg_solve_triangular_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve_triangular\n"); + return -1; + } + mlx_linalg_svd_ptr = dlsym(handle, "mlx_linalg_svd"); + if (mlx_linalg_svd_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_svd\n"); + return -1; + } + mlx_linalg_tri_inv_ptr = dlsym(handle, "mlx_linalg_tri_inv"); + if (mlx_linalg_tri_inv_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_tri_inv\n"); + return -1; + } + mlx_map_string_to_array_new_ptr = dlsym(handle, "mlx_map_string_to_array_new"); + if (mlx_map_string_to_array_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_new\n"); + return -1; + } + mlx_map_string_to_array_set_ptr = dlsym(handle, "mlx_map_string_to_array_set"); + if (mlx_map_string_to_array_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_set\n"); + return -1; + } + mlx_map_string_to_array_free_ptr = dlsym(handle, "mlx_map_string_to_array_free"); + if (mlx_map_string_to_array_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_free\n"); + return -1; + } + mlx_map_string_to_array_insert_ptr = dlsym(handle, "mlx_map_string_to_array_insert"); + if (mlx_map_string_to_array_insert_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_insert\n"); + return -1; + } + mlx_map_string_to_array_get_ptr = dlsym(handle, "mlx_map_string_to_array_get"); + if (mlx_map_string_to_array_get_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_get\n"); + return -1; + } + mlx_map_string_to_array_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_new"); + if (mlx_map_string_to_array_iterator_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_new\n"); + return -1; + } + mlx_map_string_to_array_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_free"); + if (mlx_map_string_to_array_iterator_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_free\n"); + return -1; + } + mlx_map_string_to_array_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_next"); + if (mlx_map_string_to_array_iterator_next_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_next\n"); + return -1; + } + mlx_map_string_to_string_new_ptr = dlsym(handle, "mlx_map_string_to_string_new"); + if (mlx_map_string_to_string_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_new\n"); + return -1; + } + mlx_map_string_to_string_set_ptr = dlsym(handle, "mlx_map_string_to_string_set"); + if (mlx_map_string_to_string_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_set\n"); + return -1; + } + mlx_map_string_to_string_free_ptr = dlsym(handle, "mlx_map_string_to_string_free"); + if (mlx_map_string_to_string_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_free\n"); + return -1; + } + mlx_map_string_to_string_insert_ptr = dlsym(handle, "mlx_map_string_to_string_insert"); + if (mlx_map_string_to_string_insert_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_insert\n"); + return -1; + } + mlx_map_string_to_string_get_ptr = dlsym(handle, "mlx_map_string_to_string_get"); + if (mlx_map_string_to_string_get_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_get\n"); + return -1; + } + mlx_map_string_to_string_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_new"); + if (mlx_map_string_to_string_iterator_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_new\n"); + return -1; + } + mlx_map_string_to_string_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_free"); + if (mlx_map_string_to_string_iterator_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_free\n"); + return -1; + } + mlx_map_string_to_string_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_next"); + if (mlx_map_string_to_string_iterator_next_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_next\n"); + return -1; + } + mlx_clear_cache_ptr = dlsym(handle, "mlx_clear_cache"); + if (mlx_clear_cache_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_clear_cache\n"); + return -1; + } + mlx_get_active_memory_ptr = dlsym(handle, "mlx_get_active_memory"); + if (mlx_get_active_memory_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_get_active_memory\n"); + return -1; + } + mlx_get_cache_memory_ptr = dlsym(handle, "mlx_get_cache_memory"); + if (mlx_get_cache_memory_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_get_cache_memory\n"); + return -1; + } + mlx_get_memory_limit_ptr = dlsym(handle, "mlx_get_memory_limit"); + if (mlx_get_memory_limit_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_get_memory_limit\n"); + return -1; + } + mlx_get_peak_memory_ptr = dlsym(handle, "mlx_get_peak_memory"); + if (mlx_get_peak_memory_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_get_peak_memory\n"); + return -1; + } + mlx_reset_peak_memory_ptr = dlsym(handle, "mlx_reset_peak_memory"); + if (mlx_reset_peak_memory_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_reset_peak_memory\n"); + return -1; + } + mlx_set_cache_limit_ptr = dlsym(handle, "mlx_set_cache_limit"); + if (mlx_set_cache_limit_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_set_cache_limit\n"); + return -1; + } + mlx_set_memory_limit_ptr = dlsym(handle, "mlx_set_memory_limit"); + if (mlx_set_memory_limit_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_set_memory_limit\n"); + return -1; + } + mlx_set_wired_limit_ptr = dlsym(handle, "mlx_set_wired_limit"); + if (mlx_set_wired_limit_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n"); + return -1; + } + mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info"); + if (mlx_metal_device_info_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n"); + return -1; + } + mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available"); + if (mlx_metal_is_available_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n"); + return -1; + } + mlx_metal_start_capture_ptr = dlsym(handle, "mlx_metal_start_capture"); + if (mlx_metal_start_capture_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_start_capture\n"); + return -1; + } + mlx_metal_stop_capture_ptr = dlsym(handle, "mlx_metal_stop_capture"); + if (mlx_metal_stop_capture_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_stop_capture\n"); + return -1; + } + mlx_abs_ptr = dlsym(handle, "mlx_abs"); + if (mlx_abs_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_abs\n"); + return -1; + } + mlx_add_ptr = dlsym(handle, "mlx_add"); + if (mlx_add_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_add\n"); + return -1; + } + mlx_addmm_ptr = dlsym(handle, "mlx_addmm"); + if (mlx_addmm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_addmm\n"); + return -1; + } + mlx_all_axes_ptr = dlsym(handle, "mlx_all_axes"); + if (mlx_all_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axes\n"); + return -1; + } + mlx_all_axis_ptr = dlsym(handle, "mlx_all_axis"); + if (mlx_all_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axis\n"); + return -1; + } + mlx_all_ptr = dlsym(handle, "mlx_all"); + if (mlx_all_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_all\n"); + return -1; + } + mlx_allclose_ptr = dlsym(handle, "mlx_allclose"); + if (mlx_allclose_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_allclose\n"); + return -1; + } + mlx_any_axes_ptr = dlsym(handle, "mlx_any_axes"); + if (mlx_any_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axes\n"); + return -1; + } + mlx_any_axis_ptr = dlsym(handle, "mlx_any_axis"); + if (mlx_any_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axis\n"); + return -1; + } + mlx_any_ptr = dlsym(handle, "mlx_any"); + if (mlx_any_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_any\n"); + return -1; + } + mlx_arange_ptr = dlsym(handle, "mlx_arange"); + if (mlx_arange_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_arange\n"); + return -1; + } + mlx_arccos_ptr = dlsym(handle, "mlx_arccos"); + if (mlx_arccos_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_arccos\n"); + return -1; + } + mlx_arccosh_ptr = dlsym(handle, "mlx_arccosh"); + if (mlx_arccosh_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_arccosh\n"); + return -1; + } + mlx_arcsin_ptr = dlsym(handle, "mlx_arcsin"); + if (mlx_arcsin_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsin\n"); + return -1; + } + mlx_arcsinh_ptr = dlsym(handle, "mlx_arcsinh"); + if (mlx_arcsinh_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsinh\n"); + return -1; + } + mlx_arctan_ptr = dlsym(handle, "mlx_arctan"); + if (mlx_arctan_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan\n"); + return -1; + } + mlx_arctan2_ptr = dlsym(handle, "mlx_arctan2"); + if (mlx_arctan2_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan2\n"); + return -1; + } + mlx_arctanh_ptr = dlsym(handle, "mlx_arctanh"); + if (mlx_arctanh_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_arctanh\n"); + return -1; + } + mlx_argmax_axis_ptr = dlsym(handle, "mlx_argmax_axis"); + if (mlx_argmax_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax_axis\n"); + return -1; + } + mlx_argmax_ptr = dlsym(handle, "mlx_argmax"); + if (mlx_argmax_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax\n"); + return -1; + } + mlx_argmin_axis_ptr = dlsym(handle, "mlx_argmin_axis"); + if (mlx_argmin_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin_axis\n"); + return -1; + } + mlx_argmin_ptr = dlsym(handle, "mlx_argmin"); + if (mlx_argmin_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin\n"); + return -1; + } + mlx_argpartition_axis_ptr = dlsym(handle, "mlx_argpartition_axis"); + if (mlx_argpartition_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition_axis\n"); + return -1; + } + mlx_argpartition_ptr = dlsym(handle, "mlx_argpartition"); + if (mlx_argpartition_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition\n"); + return -1; + } + mlx_argsort_axis_ptr = dlsym(handle, "mlx_argsort_axis"); + if (mlx_argsort_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort_axis\n"); + return -1; + } + mlx_argsort_ptr = dlsym(handle, "mlx_argsort"); + if (mlx_argsort_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort\n"); + return -1; + } + mlx_array_equal_ptr = dlsym(handle, "mlx_array_equal"); + if (mlx_array_equal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_array_equal\n"); + return -1; + } + mlx_as_strided_ptr = dlsym(handle, "mlx_as_strided"); + if (mlx_as_strided_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_as_strided\n"); + return -1; + } + mlx_astype_ptr = dlsym(handle, "mlx_astype"); + if (mlx_astype_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_astype\n"); + return -1; + } + mlx_atleast_1d_ptr = dlsym(handle, "mlx_atleast_1d"); + if (mlx_atleast_1d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_1d\n"); + return -1; + } + mlx_atleast_2d_ptr = dlsym(handle, "mlx_atleast_2d"); + if (mlx_atleast_2d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_2d\n"); + return -1; + } + mlx_atleast_3d_ptr = dlsym(handle, "mlx_atleast_3d"); + if (mlx_atleast_3d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n"); + return -1; + } + mlx_bitwise_and_ptr = dlsym(handle, "mlx_bitwise_and"); + if (mlx_bitwise_and_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n"); + return -1; + } + mlx_bitwise_invert_ptr = dlsym(handle, "mlx_bitwise_invert"); + if (mlx_bitwise_invert_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_invert\n"); + return -1; + } + mlx_bitwise_or_ptr = dlsym(handle, "mlx_bitwise_or"); + if (mlx_bitwise_or_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_or\n"); + return -1; + } + mlx_bitwise_xor_ptr = dlsym(handle, "mlx_bitwise_xor"); + if (mlx_bitwise_xor_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n"); + return -1; + } + mlx_block_masked_mm_ptr = dlsym(handle, "mlx_block_masked_mm"); + if (mlx_block_masked_mm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n"); + return -1; + } + mlx_broadcast_arrays_ptr = dlsym(handle, "mlx_broadcast_arrays"); + if (mlx_broadcast_arrays_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_arrays\n"); + return -1; + } + mlx_broadcast_to_ptr = dlsym(handle, "mlx_broadcast_to"); + if (mlx_broadcast_to_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_to\n"); + return -1; + } + mlx_ceil_ptr = dlsym(handle, "mlx_ceil"); + if (mlx_ceil_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_ceil\n"); + return -1; + } + mlx_clip_ptr = dlsym(handle, "mlx_clip"); + if (mlx_clip_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_clip\n"); + return -1; + } + mlx_concatenate_axis_ptr = dlsym(handle, "mlx_concatenate_axis"); + if (mlx_concatenate_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate_axis\n"); + return -1; + } + mlx_concatenate_ptr = dlsym(handle, "mlx_concatenate"); + if (mlx_concatenate_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate\n"); + return -1; + } + mlx_conjugate_ptr = dlsym(handle, "mlx_conjugate"); + if (mlx_conjugate_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_conjugate\n"); + return -1; + } + mlx_contiguous_ptr = dlsym(handle, "mlx_contiguous"); + if (mlx_contiguous_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_contiguous\n"); + return -1; + } + mlx_conv1d_ptr = dlsym(handle, "mlx_conv1d"); + if (mlx_conv1d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_conv1d\n"); + return -1; + } + mlx_conv2d_ptr = dlsym(handle, "mlx_conv2d"); + if (mlx_conv2d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_conv2d\n"); + return -1; + } + mlx_conv3d_ptr = dlsym(handle, "mlx_conv3d"); + if (mlx_conv3d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_conv3d\n"); + return -1; + } + mlx_conv_general_ptr = dlsym(handle, "mlx_conv_general"); + if (mlx_conv_general_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_general\n"); + return -1; + } + mlx_conv_transpose1d_ptr = dlsym(handle, "mlx_conv_transpose1d"); + if (mlx_conv_transpose1d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose1d\n"); + return -1; + } + mlx_conv_transpose2d_ptr = dlsym(handle, "mlx_conv_transpose2d"); + if (mlx_conv_transpose2d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose2d\n"); + return -1; + } + mlx_conv_transpose3d_ptr = dlsym(handle, "mlx_conv_transpose3d"); + if (mlx_conv_transpose3d_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose3d\n"); + return -1; + } + mlx_copy_ptr = dlsym(handle, "mlx_copy"); + if (mlx_copy_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_copy\n"); + return -1; + } + mlx_cos_ptr = dlsym(handle, "mlx_cos"); + if (mlx_cos_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_cos\n"); + return -1; + } + mlx_cosh_ptr = dlsym(handle, "mlx_cosh"); + if (mlx_cosh_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_cosh\n"); + return -1; + } + mlx_cummax_ptr = dlsym(handle, "mlx_cummax"); + if (mlx_cummax_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_cummax\n"); + return -1; + } + mlx_cummin_ptr = dlsym(handle, "mlx_cummin"); + if (mlx_cummin_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_cummin\n"); + return -1; + } + mlx_cumprod_ptr = dlsym(handle, "mlx_cumprod"); + if (mlx_cumprod_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_cumprod\n"); + return -1; + } + mlx_cumsum_ptr = dlsym(handle, "mlx_cumsum"); + if (mlx_cumsum_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_cumsum\n"); + return -1; + } + mlx_degrees_ptr = dlsym(handle, "mlx_degrees"); + if (mlx_degrees_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_degrees\n"); + return -1; + } + mlx_depends_ptr = dlsym(handle, "mlx_depends"); + if (mlx_depends_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_depends\n"); + return -1; + } + mlx_dequantize_ptr = dlsym(handle, "mlx_dequantize"); + if (mlx_dequantize_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_dequantize\n"); + return -1; + } + mlx_diag_ptr = dlsym(handle, "mlx_diag"); + if (mlx_diag_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_diag\n"); + return -1; + } + mlx_diagonal_ptr = dlsym(handle, "mlx_diagonal"); + if (mlx_diagonal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_diagonal\n"); + return -1; + } + mlx_divide_ptr = dlsym(handle, "mlx_divide"); + if (mlx_divide_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_divide\n"); + return -1; + } + mlx_divmod_ptr = dlsym(handle, "mlx_divmod"); + if (mlx_divmod_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_divmod\n"); + return -1; + } + mlx_einsum_ptr = dlsym(handle, "mlx_einsum"); + if (mlx_einsum_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_einsum\n"); + return -1; + } + mlx_equal_ptr = dlsym(handle, "mlx_equal"); + if (mlx_equal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_equal\n"); + return -1; + } + mlx_erf_ptr = dlsym(handle, "mlx_erf"); + if (mlx_erf_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_erf\n"); + return -1; + } + mlx_erfinv_ptr = dlsym(handle, "mlx_erfinv"); + if (mlx_erfinv_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_erfinv\n"); + return -1; + } + mlx_exp_ptr = dlsym(handle, "mlx_exp"); + if (mlx_exp_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_exp\n"); + return -1; + } + mlx_expand_dims_axes_ptr = dlsym(handle, "mlx_expand_dims_axes"); + if (mlx_expand_dims_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims_axes\n"); + return -1; + } + mlx_expand_dims_ptr = dlsym(handle, "mlx_expand_dims"); + if (mlx_expand_dims_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims\n"); + return -1; + } + mlx_expm1_ptr = dlsym(handle, "mlx_expm1"); + if (mlx_expm1_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_expm1\n"); + return -1; + } + mlx_eye_ptr = dlsym(handle, "mlx_eye"); + if (mlx_eye_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_eye\n"); + return -1; + } + mlx_flatten_ptr = dlsym(handle, "mlx_flatten"); + if (mlx_flatten_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_flatten\n"); + return -1; + } + mlx_floor_ptr = dlsym(handle, "mlx_floor"); + if (mlx_floor_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_floor\n"); + return -1; + } + mlx_floor_divide_ptr = dlsym(handle, "mlx_floor_divide"); + if (mlx_floor_divide_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_floor_divide\n"); + return -1; + } + mlx_from_fp8_ptr = dlsym(handle, "mlx_from_fp8"); + if (mlx_from_fp8_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_from_fp8\n"); + return -1; + } + mlx_full_ptr = dlsym(handle, "mlx_full"); + if (mlx_full_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_full\n"); + return -1; + } + mlx_full_like_ptr = dlsym(handle, "mlx_full_like"); + if (mlx_full_like_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_full_like\n"); + return -1; + } + mlx_gather_ptr = dlsym(handle, "mlx_gather"); + if (mlx_gather_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_gather\n"); + return -1; + } + mlx_gather_single_ptr = dlsym(handle, "mlx_gather_single"); + if (mlx_gather_single_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_single\n"); + return -1; + } + mlx_gather_mm_ptr = dlsym(handle, "mlx_gather_mm"); + if (mlx_gather_mm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_mm\n"); + return -1; + } + mlx_gather_qmm_ptr = dlsym(handle, "mlx_gather_qmm"); + if (mlx_gather_qmm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_qmm\n"); + return -1; + } + mlx_greater_ptr = dlsym(handle, "mlx_greater"); + if (mlx_greater_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_greater\n"); + return -1; + } + mlx_greater_equal_ptr = dlsym(handle, "mlx_greater_equal"); + if (mlx_greater_equal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_greater_equal\n"); + return -1; + } + mlx_hadamard_transform_ptr = dlsym(handle, "mlx_hadamard_transform"); + if (mlx_hadamard_transform_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n"); + return -1; + } + mlx_identity_ptr = dlsym(handle, "mlx_identity"); + if (mlx_identity_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n"); + return -1; + } + mlx_imag_ptr = dlsym(handle, "mlx_imag"); + if (mlx_imag_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_imag\n"); + return -1; + } + mlx_inner_ptr = dlsym(handle, "mlx_inner"); + if (mlx_inner_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_inner\n"); + return -1; + } + mlx_isclose_ptr = dlsym(handle, "mlx_isclose"); + if (mlx_isclose_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_isclose\n"); + return -1; + } + mlx_isfinite_ptr = dlsym(handle, "mlx_isfinite"); + if (mlx_isfinite_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_isfinite\n"); + return -1; + } + mlx_isinf_ptr = dlsym(handle, "mlx_isinf"); + if (mlx_isinf_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_isinf\n"); + return -1; + } + mlx_isnan_ptr = dlsym(handle, "mlx_isnan"); + if (mlx_isnan_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_isnan\n"); + return -1; + } + mlx_isneginf_ptr = dlsym(handle, "mlx_isneginf"); + if (mlx_isneginf_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_isneginf\n"); + return -1; + } + mlx_isposinf_ptr = dlsym(handle, "mlx_isposinf"); + if (mlx_isposinf_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_isposinf\n"); + return -1; + } + mlx_kron_ptr = dlsym(handle, "mlx_kron"); + if (mlx_kron_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_kron\n"); + return -1; + } + mlx_left_shift_ptr = dlsym(handle, "mlx_left_shift"); + if (mlx_left_shift_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_left_shift\n"); + return -1; + } + mlx_less_ptr = dlsym(handle, "mlx_less"); + if (mlx_less_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_less\n"); + return -1; + } + mlx_less_equal_ptr = dlsym(handle, "mlx_less_equal"); + if (mlx_less_equal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_less_equal\n"); + return -1; + } + mlx_linspace_ptr = dlsym(handle, "mlx_linspace"); + if (mlx_linspace_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_linspace\n"); + return -1; + } + mlx_log_ptr = dlsym(handle, "mlx_log"); + if (mlx_log_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_log\n"); + return -1; + } + mlx_log10_ptr = dlsym(handle, "mlx_log10"); + if (mlx_log10_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_log10\n"); + return -1; + } + mlx_log1p_ptr = dlsym(handle, "mlx_log1p"); + if (mlx_log1p_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_log1p\n"); + return -1; + } + mlx_log2_ptr = dlsym(handle, "mlx_log2"); + if (mlx_log2_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_log2\n"); + return -1; + } + mlx_logaddexp_ptr = dlsym(handle, "mlx_logaddexp"); + if (mlx_logaddexp_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_logaddexp\n"); + return -1; + } + mlx_logcumsumexp_ptr = dlsym(handle, "mlx_logcumsumexp"); + if (mlx_logcumsumexp_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_logcumsumexp\n"); + return -1; + } + mlx_logical_and_ptr = dlsym(handle, "mlx_logical_and"); + if (mlx_logical_and_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_and\n"); + return -1; + } + mlx_logical_not_ptr = dlsym(handle, "mlx_logical_not"); + if (mlx_logical_not_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_not\n"); + return -1; + } + mlx_logical_or_ptr = dlsym(handle, "mlx_logical_or"); + if (mlx_logical_or_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_or\n"); + return -1; + } + mlx_logsumexp_axes_ptr = dlsym(handle, "mlx_logsumexp_axes"); + if (mlx_logsumexp_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axes\n"); + return -1; + } + mlx_logsumexp_axis_ptr = dlsym(handle, "mlx_logsumexp_axis"); + if (mlx_logsumexp_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axis\n"); + return -1; + } + mlx_logsumexp_ptr = dlsym(handle, "mlx_logsumexp"); + if (mlx_logsumexp_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp\n"); + return -1; + } + mlx_masked_scatter_ptr = dlsym(handle, "mlx_masked_scatter"); + if (mlx_masked_scatter_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_masked_scatter\n"); + return -1; + } + mlx_matmul_ptr = dlsym(handle, "mlx_matmul"); + if (mlx_matmul_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_matmul\n"); + return -1; + } + mlx_max_axes_ptr = dlsym(handle, "mlx_max_axes"); + if (mlx_max_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axes\n"); + return -1; + } + mlx_max_axis_ptr = dlsym(handle, "mlx_max_axis"); + if (mlx_max_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axis\n"); + return -1; + } + mlx_max_ptr = dlsym(handle, "mlx_max"); + if (mlx_max_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_max\n"); + return -1; + } + mlx_maximum_ptr = dlsym(handle, "mlx_maximum"); + if (mlx_maximum_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_maximum\n"); + return -1; + } + mlx_mean_axes_ptr = dlsym(handle, "mlx_mean_axes"); + if (mlx_mean_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axes\n"); + return -1; + } + mlx_mean_axis_ptr = dlsym(handle, "mlx_mean_axis"); + if (mlx_mean_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axis\n"); + return -1; + } + mlx_mean_ptr = dlsym(handle, "mlx_mean"); + if (mlx_mean_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_mean\n"); + return -1; + } + mlx_median_ptr = dlsym(handle, "mlx_median"); + if (mlx_median_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_median\n"); + return -1; + } + mlx_meshgrid_ptr = dlsym(handle, "mlx_meshgrid"); + if (mlx_meshgrid_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_meshgrid\n"); + return -1; + } + mlx_min_axes_ptr = dlsym(handle, "mlx_min_axes"); + if (mlx_min_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axes\n"); + return -1; + } + mlx_min_axis_ptr = dlsym(handle, "mlx_min_axis"); + if (mlx_min_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axis\n"); + return -1; + } + mlx_min_ptr = dlsym(handle, "mlx_min"); + if (mlx_min_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_min\n"); + return -1; + } + mlx_minimum_ptr = dlsym(handle, "mlx_minimum"); + if (mlx_minimum_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_minimum\n"); + return -1; + } + mlx_moveaxis_ptr = dlsym(handle, "mlx_moveaxis"); + if (mlx_moveaxis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_moveaxis\n"); + return -1; + } + mlx_multiply_ptr = dlsym(handle, "mlx_multiply"); + if (mlx_multiply_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_multiply\n"); + return -1; + } + mlx_nan_to_num_ptr = dlsym(handle, "mlx_nan_to_num"); + if (mlx_nan_to_num_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_nan_to_num\n"); + return -1; + } + mlx_negative_ptr = dlsym(handle, "mlx_negative"); + if (mlx_negative_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_negative\n"); + return -1; + } + mlx_not_equal_ptr = dlsym(handle, "mlx_not_equal"); + if (mlx_not_equal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_not_equal\n"); + return -1; + } + mlx_number_of_elements_ptr = dlsym(handle, "mlx_number_of_elements"); + if (mlx_number_of_elements_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_number_of_elements\n"); + return -1; + } + mlx_ones_ptr = dlsym(handle, "mlx_ones"); + if (mlx_ones_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_ones\n"); + return -1; + } + mlx_ones_like_ptr = dlsym(handle, "mlx_ones_like"); + if (mlx_ones_like_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_ones_like\n"); + return -1; + } + mlx_outer_ptr = dlsym(handle, "mlx_outer"); + if (mlx_outer_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_outer\n"); + return -1; + } + mlx_pad_ptr = dlsym(handle, "mlx_pad"); + if (mlx_pad_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_pad\n"); + return -1; + } + mlx_pad_symmetric_ptr = dlsym(handle, "mlx_pad_symmetric"); + if (mlx_pad_symmetric_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_pad_symmetric\n"); + return -1; + } + mlx_partition_axis_ptr = dlsym(handle, "mlx_partition_axis"); + if (mlx_partition_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_partition_axis\n"); + return -1; + } + mlx_partition_ptr = dlsym(handle, "mlx_partition"); + if (mlx_partition_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_partition\n"); + return -1; + } + mlx_power_ptr = dlsym(handle, "mlx_power"); + if (mlx_power_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_power\n"); + return -1; + } + mlx_prod_axes_ptr = dlsym(handle, "mlx_prod_axes"); + if (mlx_prod_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axes\n"); + return -1; + } + mlx_prod_axis_ptr = dlsym(handle, "mlx_prod_axis"); + if (mlx_prod_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axis\n"); + return -1; + } + mlx_prod_ptr = dlsym(handle, "mlx_prod"); + if (mlx_prod_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_prod\n"); + return -1; + } + mlx_put_along_axis_ptr = dlsym(handle, "mlx_put_along_axis"); + if (mlx_put_along_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_put_along_axis\n"); + return -1; + } + mlx_qqmm_ptr = dlsym(handle, "mlx_qqmm"); + if (mlx_qqmm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_qqmm\n"); + return -1; + } + mlx_quantize_ptr = dlsym(handle, "mlx_quantize"); + if (mlx_quantize_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_quantize\n"); + return -1; + } + mlx_quantized_matmul_ptr = dlsym(handle, "mlx_quantized_matmul"); + if (mlx_quantized_matmul_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_quantized_matmul\n"); + return -1; + } + mlx_radians_ptr = dlsym(handle, "mlx_radians"); + if (mlx_radians_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_radians\n"); + return -1; + } + mlx_real_ptr = dlsym(handle, "mlx_real"); + if (mlx_real_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_real\n"); + return -1; + } + mlx_reciprocal_ptr = dlsym(handle, "mlx_reciprocal"); + if (mlx_reciprocal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_reciprocal\n"); + return -1; + } + mlx_remainder_ptr = dlsym(handle, "mlx_remainder"); + if (mlx_remainder_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_remainder\n"); + return -1; + } + mlx_repeat_axis_ptr = dlsym(handle, "mlx_repeat_axis"); + if (mlx_repeat_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat_axis\n"); + return -1; + } + mlx_repeat_ptr = dlsym(handle, "mlx_repeat"); + if (mlx_repeat_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat\n"); + return -1; + } + mlx_reshape_ptr = dlsym(handle, "mlx_reshape"); + if (mlx_reshape_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_reshape\n"); + return -1; + } + mlx_right_shift_ptr = dlsym(handle, "mlx_right_shift"); + if (mlx_right_shift_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_right_shift\n"); + return -1; + } + mlx_roll_axis_ptr = dlsym(handle, "mlx_roll_axis"); + if (mlx_roll_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axis\n"); + return -1; + } + mlx_roll_axes_ptr = dlsym(handle, "mlx_roll_axes"); + if (mlx_roll_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axes\n"); + return -1; + } + mlx_roll_ptr = dlsym(handle, "mlx_roll"); + if (mlx_roll_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_roll\n"); + return -1; + } + mlx_round_ptr = dlsym(handle, "mlx_round"); + if (mlx_round_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_round\n"); + return -1; + } + mlx_rsqrt_ptr = dlsym(handle, "mlx_rsqrt"); + if (mlx_rsqrt_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_rsqrt\n"); + return -1; + } + mlx_scatter_ptr = dlsym(handle, "mlx_scatter"); + if (mlx_scatter_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter\n"); + return -1; + } + mlx_scatter_single_ptr = dlsym(handle, "mlx_scatter_single"); + if (mlx_scatter_single_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_single\n"); + return -1; + } + mlx_scatter_add_ptr = dlsym(handle, "mlx_scatter_add"); + if (mlx_scatter_add_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add\n"); + return -1; + } + mlx_scatter_add_single_ptr = dlsym(handle, "mlx_scatter_add_single"); + if (mlx_scatter_add_single_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_single\n"); + return -1; + } + mlx_scatter_add_axis_ptr = dlsym(handle, "mlx_scatter_add_axis"); + if (mlx_scatter_add_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_axis\n"); + return -1; + } + mlx_scatter_max_ptr = dlsym(handle, "mlx_scatter_max"); + if (mlx_scatter_max_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max\n"); + return -1; + } + mlx_scatter_max_single_ptr = dlsym(handle, "mlx_scatter_max_single"); + if (mlx_scatter_max_single_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max_single\n"); + return -1; + } + mlx_scatter_min_ptr = dlsym(handle, "mlx_scatter_min"); + if (mlx_scatter_min_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min\n"); + return -1; + } + mlx_scatter_min_single_ptr = dlsym(handle, "mlx_scatter_min_single"); + if (mlx_scatter_min_single_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min_single\n"); + return -1; + } + mlx_scatter_prod_ptr = dlsym(handle, "mlx_scatter_prod"); + if (mlx_scatter_prod_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod\n"); + return -1; + } + mlx_scatter_prod_single_ptr = dlsym(handle, "mlx_scatter_prod_single"); + if (mlx_scatter_prod_single_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod_single\n"); + return -1; + } + mlx_segmented_mm_ptr = dlsym(handle, "mlx_segmented_mm"); + if (mlx_segmented_mm_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_segmented_mm\n"); + return -1; + } + mlx_sigmoid_ptr = dlsym(handle, "mlx_sigmoid"); + if (mlx_sigmoid_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sigmoid\n"); + return -1; + } + mlx_sign_ptr = dlsym(handle, "mlx_sign"); + if (mlx_sign_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sign\n"); + return -1; + } + mlx_sin_ptr = dlsym(handle, "mlx_sin"); + if (mlx_sin_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sin\n"); + return -1; + } + mlx_sinh_ptr = dlsym(handle, "mlx_sinh"); + if (mlx_sinh_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sinh\n"); + return -1; + } + mlx_slice_ptr = dlsym(handle, "mlx_slice"); + if (mlx_slice_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_slice\n"); + return -1; + } + mlx_slice_dynamic_ptr = dlsym(handle, "mlx_slice_dynamic"); + if (mlx_slice_dynamic_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_dynamic\n"); + return -1; + } + mlx_slice_update_ptr = dlsym(handle, "mlx_slice_update"); + if (mlx_slice_update_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update\n"); + return -1; + } + mlx_slice_update_dynamic_ptr = dlsym(handle, "mlx_slice_update_dynamic"); + if (mlx_slice_update_dynamic_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_dynamic\n"); + return -1; + } + mlx_softmax_axes_ptr = dlsym(handle, "mlx_softmax_axes"); + if (mlx_softmax_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axes\n"); + return -1; + } + mlx_softmax_axis_ptr = dlsym(handle, "mlx_softmax_axis"); + if (mlx_softmax_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axis\n"); + return -1; + } + mlx_softmax_ptr = dlsym(handle, "mlx_softmax"); + if (mlx_softmax_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax\n"); + return -1; + } + mlx_sort_axis_ptr = dlsym(handle, "mlx_sort_axis"); + if (mlx_sort_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sort_axis\n"); + return -1; + } + mlx_sort_ptr = dlsym(handle, "mlx_sort"); + if (mlx_sort_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sort\n"); + return -1; + } + mlx_split_ptr = dlsym(handle, "mlx_split"); + if (mlx_split_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_split\n"); + return -1; + } + mlx_split_sections_ptr = dlsym(handle, "mlx_split_sections"); + if (mlx_split_sections_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_split_sections\n"); + return -1; + } + mlx_sqrt_ptr = dlsym(handle, "mlx_sqrt"); + if (mlx_sqrt_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sqrt\n"); + return -1; + } + mlx_square_ptr = dlsym(handle, "mlx_square"); + if (mlx_square_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_square\n"); + return -1; + } + mlx_squeeze_axes_ptr = dlsym(handle, "mlx_squeeze_axes"); + if (mlx_squeeze_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axes\n"); + return -1; + } + mlx_squeeze_axis_ptr = dlsym(handle, "mlx_squeeze_axis"); + if (mlx_squeeze_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axis\n"); + return -1; + } + mlx_squeeze_ptr = dlsym(handle, "mlx_squeeze"); + if (mlx_squeeze_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze\n"); + return -1; + } + mlx_stack_axis_ptr = dlsym(handle, "mlx_stack_axis"); + if (mlx_stack_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stack_axis\n"); + return -1; + } + mlx_stack_ptr = dlsym(handle, "mlx_stack"); + if (mlx_stack_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stack\n"); + return -1; + } + mlx_std_axes_ptr = dlsym(handle, "mlx_std_axes"); + if (mlx_std_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axes\n"); + return -1; + } + mlx_std_axis_ptr = dlsym(handle, "mlx_std_axis"); + if (mlx_std_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axis\n"); + return -1; + } + mlx_std_ptr = dlsym(handle, "mlx_std"); + if (mlx_std_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_std\n"); + return -1; + } + mlx_stop_gradient_ptr = dlsym(handle, "mlx_stop_gradient"); + if (mlx_stop_gradient_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stop_gradient\n"); + return -1; + } + mlx_subtract_ptr = dlsym(handle, "mlx_subtract"); + if (mlx_subtract_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_subtract\n"); + return -1; + } + mlx_sum_axes_ptr = dlsym(handle, "mlx_sum_axes"); + if (mlx_sum_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axes\n"); + return -1; + } + mlx_sum_axis_ptr = dlsym(handle, "mlx_sum_axis"); + if (mlx_sum_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axis\n"); + return -1; + } + mlx_sum_ptr = dlsym(handle, "mlx_sum"); + if (mlx_sum_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_sum\n"); + return -1; + } + mlx_swapaxes_ptr = dlsym(handle, "mlx_swapaxes"); + if (mlx_swapaxes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_swapaxes\n"); + return -1; + } + mlx_take_axis_ptr = dlsym(handle, "mlx_take_axis"); + if (mlx_take_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_take_axis\n"); + return -1; + } + mlx_take_ptr = dlsym(handle, "mlx_take"); + if (mlx_take_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_take\n"); + return -1; + } + mlx_take_along_axis_ptr = dlsym(handle, "mlx_take_along_axis"); + if (mlx_take_along_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_take_along_axis\n"); + return -1; + } + mlx_tan_ptr = dlsym(handle, "mlx_tan"); + if (mlx_tan_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_tan\n"); + return -1; + } + mlx_tanh_ptr = dlsym(handle, "mlx_tanh"); + if (mlx_tanh_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_tanh\n"); + return -1; + } + mlx_tensordot_ptr = dlsym(handle, "mlx_tensordot"); + if (mlx_tensordot_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot\n"); + return -1; + } + mlx_tensordot_axis_ptr = dlsym(handle, "mlx_tensordot_axis"); + if (mlx_tensordot_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot_axis\n"); + return -1; + } + mlx_tile_ptr = dlsym(handle, "mlx_tile"); + if (mlx_tile_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_tile\n"); + return -1; + } + mlx_to_fp8_ptr = dlsym(handle, "mlx_to_fp8"); + if (mlx_to_fp8_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_to_fp8\n"); + return -1; + } + mlx_topk_axis_ptr = dlsym(handle, "mlx_topk_axis"); + if (mlx_topk_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_topk_axis\n"); + return -1; + } + mlx_topk_ptr = dlsym(handle, "mlx_topk"); + if (mlx_topk_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_topk\n"); + return -1; + } + mlx_trace_ptr = dlsym(handle, "mlx_trace"); + if (mlx_trace_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_trace\n"); + return -1; + } + mlx_transpose_axes_ptr = dlsym(handle, "mlx_transpose_axes"); + if (mlx_transpose_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose_axes\n"); + return -1; + } + mlx_transpose_ptr = dlsym(handle, "mlx_transpose"); + if (mlx_transpose_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose\n"); + return -1; + } + mlx_tri_ptr = dlsym(handle, "mlx_tri"); + if (mlx_tri_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_tri\n"); + return -1; + } + mlx_tril_ptr = dlsym(handle, "mlx_tril"); + if (mlx_tril_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_tril\n"); + return -1; + } + mlx_triu_ptr = dlsym(handle, "mlx_triu"); + if (mlx_triu_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_triu\n"); + return -1; + } + mlx_unflatten_ptr = dlsym(handle, "mlx_unflatten"); + if (mlx_unflatten_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_unflatten\n"); + return -1; + } + mlx_var_axes_ptr = dlsym(handle, "mlx_var_axes"); + if (mlx_var_axes_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axes\n"); + return -1; + } + mlx_var_axis_ptr = dlsym(handle, "mlx_var_axis"); + if (mlx_var_axis_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axis\n"); + return -1; + } + mlx_var_ptr = dlsym(handle, "mlx_var"); + if (mlx_var_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_var\n"); + return -1; + } + mlx_view_ptr = dlsym(handle, "mlx_view"); + if (mlx_view_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_view\n"); + return -1; + } + mlx_where_ptr = dlsym(handle, "mlx_where"); + if (mlx_where_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_where\n"); + return -1; + } + mlx_zeros_ptr = dlsym(handle, "mlx_zeros"); + if (mlx_zeros_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros\n"); + return -1; + } + mlx_zeros_like_ptr = dlsym(handle, "mlx_zeros_like"); + if (mlx_zeros_like_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros_like\n"); + return -1; + } + mlx_random_bernoulli_ptr = dlsym(handle, "mlx_random_bernoulli"); + if (mlx_random_bernoulli_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bernoulli\n"); + return -1; + } + mlx_random_bits_ptr = dlsym(handle, "mlx_random_bits"); + if (mlx_random_bits_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bits\n"); + return -1; + } + mlx_random_categorical_shape_ptr = dlsym(handle, "mlx_random_categorical_shape"); + if (mlx_random_categorical_shape_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_shape\n"); + return -1; + } + mlx_random_categorical_num_samples_ptr = dlsym(handle, "mlx_random_categorical_num_samples"); + if (mlx_random_categorical_num_samples_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_num_samples\n"); + return -1; + } + mlx_random_categorical_ptr = dlsym(handle, "mlx_random_categorical"); + if (mlx_random_categorical_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical\n"); + return -1; + } + mlx_random_gumbel_ptr = dlsym(handle, "mlx_random_gumbel"); + if (mlx_random_gumbel_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_gumbel\n"); + return -1; + } + mlx_random_key_ptr = dlsym(handle, "mlx_random_key"); + if (mlx_random_key_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_key\n"); + return -1; + } + mlx_random_laplace_ptr = dlsym(handle, "mlx_random_laplace"); + if (mlx_random_laplace_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_laplace\n"); + return -1; + } + mlx_random_multivariate_normal_ptr = dlsym(handle, "mlx_random_multivariate_normal"); + if (mlx_random_multivariate_normal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_multivariate_normal\n"); + return -1; + } + mlx_random_normal_broadcast_ptr = dlsym(handle, "mlx_random_normal_broadcast"); + if (mlx_random_normal_broadcast_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal_broadcast\n"); + return -1; + } + mlx_random_normal_ptr = dlsym(handle, "mlx_random_normal"); + if (mlx_random_normal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal\n"); + return -1; + } + mlx_random_permutation_ptr = dlsym(handle, "mlx_random_permutation"); + if (mlx_random_permutation_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation\n"); + return -1; + } + mlx_random_permutation_arange_ptr = dlsym(handle, "mlx_random_permutation_arange"); + if (mlx_random_permutation_arange_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation_arange\n"); + return -1; + } + mlx_random_randint_ptr = dlsym(handle, "mlx_random_randint"); + if (mlx_random_randint_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_randint\n"); + return -1; + } + mlx_random_seed_ptr = dlsym(handle, "mlx_random_seed"); + if (mlx_random_seed_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_seed\n"); + return -1; + } + mlx_random_split_num_ptr = dlsym(handle, "mlx_random_split_num"); + if (mlx_random_split_num_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split_num\n"); + return -1; + } + mlx_random_split_ptr = dlsym(handle, "mlx_random_split"); + if (mlx_random_split_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split\n"); + return -1; + } + mlx_random_truncated_normal_ptr = dlsym(handle, "mlx_random_truncated_normal"); + if (mlx_random_truncated_normal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_truncated_normal\n"); + return -1; + } + mlx_random_uniform_ptr = dlsym(handle, "mlx_random_uniform"); + if (mlx_random_uniform_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_random_uniform\n"); + return -1; + } + mlx_stream_new_ptr = dlsym(handle, "mlx_stream_new"); + if (mlx_stream_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new\n"); + return -1; + } + mlx_stream_new_device_ptr = dlsym(handle, "mlx_stream_new_device"); + if (mlx_stream_new_device_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new_device\n"); + return -1; + } + mlx_stream_set_ptr = dlsym(handle, "mlx_stream_set"); + if (mlx_stream_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_set\n"); + return -1; + } + mlx_stream_free_ptr = dlsym(handle, "mlx_stream_free"); + if (mlx_stream_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_free\n"); + return -1; + } + mlx_stream_tostring_ptr = dlsym(handle, "mlx_stream_tostring"); + if (mlx_stream_tostring_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_tostring\n"); + return -1; + } + mlx_stream_equal_ptr = dlsym(handle, "mlx_stream_equal"); + if (mlx_stream_equal_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_equal\n"); + return -1; + } + mlx_stream_get_device_ptr = dlsym(handle, "mlx_stream_get_device"); + if (mlx_stream_get_device_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_device\n"); + return -1; + } + mlx_stream_get_index_ptr = dlsym(handle, "mlx_stream_get_index"); + if (mlx_stream_get_index_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_index\n"); + return -1; + } + mlx_synchronize_ptr = dlsym(handle, "mlx_synchronize"); + if (mlx_synchronize_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_synchronize\n"); + return -1; + } + mlx_get_default_stream_ptr = dlsym(handle, "mlx_get_default_stream"); + if (mlx_get_default_stream_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_stream\n"); + return -1; + } + mlx_set_default_stream_ptr = dlsym(handle, "mlx_set_default_stream"); + if (mlx_set_default_stream_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_stream\n"); + return -1; + } + mlx_default_cpu_stream_new_ptr = dlsym(handle, "mlx_default_cpu_stream_new"); + if (mlx_default_cpu_stream_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_default_cpu_stream_new\n"); + return -1; + } + mlx_default_gpu_stream_new_ptr = dlsym(handle, "mlx_default_gpu_stream_new"); + if (mlx_default_gpu_stream_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_default_gpu_stream_new\n"); + return -1; + } + mlx_string_new_ptr = dlsym(handle, "mlx_string_new"); + if (mlx_string_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new\n"); + return -1; + } + mlx_string_new_data_ptr = dlsym(handle, "mlx_string_new_data"); + if (mlx_string_new_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new_data\n"); + return -1; + } + mlx_string_set_ptr = dlsym(handle, "mlx_string_set"); + if (mlx_string_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_string_set\n"); + return -1; + } + mlx_string_data_ptr = dlsym(handle, "mlx_string_data"); + if (mlx_string_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_string_data\n"); + return -1; + } + mlx_string_free_ptr = dlsym(handle, "mlx_string_free"); + if (mlx_string_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_string_free\n"); + return -1; + } + mlx_async_eval_ptr = dlsym(handle, "mlx_async_eval"); + if (mlx_async_eval_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_async_eval\n"); + return -1; + } + mlx_checkpoint_ptr = dlsym(handle, "mlx_checkpoint"); + if (mlx_checkpoint_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_checkpoint\n"); + return -1; + } + mlx_custom_function_ptr = dlsym(handle, "mlx_custom_function"); + if (mlx_custom_function_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_function\n"); + return -1; + } + mlx_custom_vjp_ptr = dlsym(handle, "mlx_custom_vjp"); + if (mlx_custom_vjp_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_vjp\n"); + return -1; + } + mlx_eval_ptr = dlsym(handle, "mlx_eval"); + if (mlx_eval_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_eval\n"); + return -1; + } + mlx_jvp_ptr = dlsym(handle, "mlx_jvp"); + if (mlx_jvp_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_jvp\n"); + return -1; + } + mlx_value_and_grad_ptr = dlsym(handle, "mlx_value_and_grad"); + if (mlx_value_and_grad_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_value_and_grad\n"); + return -1; + } + mlx_vjp_ptr = dlsym(handle, "mlx_vjp"); + if (mlx_vjp_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vjp\n"); + return -1; + } + mlx_detail_vmap_replace_ptr = dlsym(handle, "mlx_detail_vmap_replace"); + if (mlx_detail_vmap_replace_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_replace\n"); + return -1; + } + mlx_detail_vmap_trace_ptr = dlsym(handle, "mlx_detail_vmap_trace"); + if (mlx_detail_vmap_trace_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_trace\n"); + return -1; + } + mlx_vector_array_new_ptr = dlsym(handle, "mlx_vector_array_new"); + if (mlx_vector_array_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new\n"); + return -1; + } + mlx_vector_array_set_ptr = dlsym(handle, "mlx_vector_array_set"); + if (mlx_vector_array_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set\n"); + return -1; + } + mlx_vector_array_free_ptr = dlsym(handle, "mlx_vector_array_free"); + if (mlx_vector_array_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_free\n"); + return -1; + } + mlx_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_array_new_data"); + if (mlx_vector_array_new_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_data\n"); + return -1; + } + mlx_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_array_new_value"); + if (mlx_vector_array_new_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_value\n"); + return -1; + } + mlx_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_array_set_data"); + if (mlx_vector_array_set_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_data\n"); + return -1; + } + mlx_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_array_set_value"); + if (mlx_vector_array_set_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_value\n"); + return -1; + } + mlx_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_array_append_data"); + if (mlx_vector_array_append_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_data\n"); + return -1; + } + mlx_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_array_append_value"); + if (mlx_vector_array_append_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_value\n"); + return -1; + } + mlx_vector_array_size_ptr = dlsym(handle, "mlx_vector_array_size"); + if (mlx_vector_array_size_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_size\n"); + return -1; + } + mlx_vector_array_get_ptr = dlsym(handle, "mlx_vector_array_get"); + if (mlx_vector_array_get_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_get\n"); + return -1; + } + mlx_vector_vector_array_new_ptr = dlsym(handle, "mlx_vector_vector_array_new"); + if (mlx_vector_vector_array_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new\n"); + return -1; + } + mlx_vector_vector_array_set_ptr = dlsym(handle, "mlx_vector_vector_array_set"); + if (mlx_vector_vector_array_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set\n"); + return -1; + } + mlx_vector_vector_array_free_ptr = dlsym(handle, "mlx_vector_vector_array_free"); + if (mlx_vector_vector_array_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_free\n"); + return -1; + } + mlx_vector_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_vector_array_new_data"); + if (mlx_vector_vector_array_new_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_data\n"); + return -1; + } + mlx_vector_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_vector_array_new_value"); + if (mlx_vector_vector_array_new_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_value\n"); + return -1; + } + mlx_vector_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_vector_array_set_data"); + if (mlx_vector_vector_array_set_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_data\n"); + return -1; + } + mlx_vector_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_vector_array_set_value"); + if (mlx_vector_vector_array_set_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_value\n"); + return -1; + } + mlx_vector_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_vector_array_append_data"); + if (mlx_vector_vector_array_append_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_data\n"); + return -1; + } + mlx_vector_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_vector_array_append_value"); + if (mlx_vector_vector_array_append_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_value\n"); + return -1; + } + mlx_vector_vector_array_size_ptr = dlsym(handle, "mlx_vector_vector_array_size"); + if (mlx_vector_vector_array_size_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_size\n"); + return -1; + } + mlx_vector_vector_array_get_ptr = dlsym(handle, "mlx_vector_vector_array_get"); + if (mlx_vector_vector_array_get_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_get\n"); + return -1; + } + mlx_vector_int_new_ptr = dlsym(handle, "mlx_vector_int_new"); + if (mlx_vector_int_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new\n"); + return -1; + } + mlx_vector_int_set_ptr = dlsym(handle, "mlx_vector_int_set"); + if (mlx_vector_int_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set\n"); + return -1; + } + mlx_vector_int_free_ptr = dlsym(handle, "mlx_vector_int_free"); + if (mlx_vector_int_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_free\n"); + return -1; + } + mlx_vector_int_new_data_ptr = dlsym(handle, "mlx_vector_int_new_data"); + if (mlx_vector_int_new_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_data\n"); + return -1; + } + mlx_vector_int_new_value_ptr = dlsym(handle, "mlx_vector_int_new_value"); + if (mlx_vector_int_new_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_value\n"); + return -1; + } + mlx_vector_int_set_data_ptr = dlsym(handle, "mlx_vector_int_set_data"); + if (mlx_vector_int_set_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_data\n"); + return -1; + } + mlx_vector_int_set_value_ptr = dlsym(handle, "mlx_vector_int_set_value"); + if (mlx_vector_int_set_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_value\n"); + return -1; + } + mlx_vector_int_append_data_ptr = dlsym(handle, "mlx_vector_int_append_data"); + if (mlx_vector_int_append_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_data\n"); + return -1; + } + mlx_vector_int_append_value_ptr = dlsym(handle, "mlx_vector_int_append_value"); + if (mlx_vector_int_append_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_value\n"); + return -1; + } + mlx_vector_int_size_ptr = dlsym(handle, "mlx_vector_int_size"); + if (mlx_vector_int_size_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_size\n"); + return -1; + } + mlx_vector_int_get_ptr = dlsym(handle, "mlx_vector_int_get"); + if (mlx_vector_int_get_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_get\n"); + return -1; + } + mlx_vector_string_new_ptr = dlsym(handle, "mlx_vector_string_new"); + if (mlx_vector_string_new_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new\n"); + return -1; + } + mlx_vector_string_set_ptr = dlsym(handle, "mlx_vector_string_set"); + if (mlx_vector_string_set_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set\n"); + return -1; + } + mlx_vector_string_free_ptr = dlsym(handle, "mlx_vector_string_free"); + if (mlx_vector_string_free_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_free\n"); + return -1; + } + mlx_vector_string_new_data_ptr = dlsym(handle, "mlx_vector_string_new_data"); + if (mlx_vector_string_new_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_data\n"); + return -1; + } + mlx_vector_string_new_value_ptr = dlsym(handle, "mlx_vector_string_new_value"); + if (mlx_vector_string_new_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_value\n"); + return -1; + } + mlx_vector_string_set_data_ptr = dlsym(handle, "mlx_vector_string_set_data"); + if (mlx_vector_string_set_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_data\n"); + return -1; + } + mlx_vector_string_set_value_ptr = dlsym(handle, "mlx_vector_string_set_value"); + if (mlx_vector_string_set_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_value\n"); + return -1; + } + mlx_vector_string_append_data_ptr = dlsym(handle, "mlx_vector_string_append_data"); + if (mlx_vector_string_append_data_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_data\n"); + return -1; + } + mlx_vector_string_append_value_ptr = dlsym(handle, "mlx_vector_string_append_value"); + if (mlx_vector_string_append_value_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_value\n"); + return -1; + } + mlx_vector_string_size_ptr = dlsym(handle, "mlx_vector_string_size"); + if (mlx_vector_string_size_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_size\n"); + return -1; + } + mlx_vector_string_get_ptr = dlsym(handle, "mlx_vector_string_get"); + if (mlx_vector_string_get_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_get\n"); + return -1; + } + mlx_version_ptr = dlsym(handle, "mlx_version"); + if (mlx_version_ptr == NULL) { + fprintf(stderr, "MLX: Failed to load symbol: mlx_version\n"); + return -1; + } + return 0; +} + +// Wrapper function implementations that call through function pointers +size_t mlx_dtype_size(mlx_dtype dtype) { + return mlx_dtype_size_ptr(dtype); +} + +int mlx_array_tostring(mlx_string* str, const mlx_array arr) { + return mlx_array_tostring_ptr(str, arr); +} + +mlx_array mlx_array_new(void) { + return mlx_array_new_ptr(); +} + +int mlx_array_free(mlx_array arr) { + return mlx_array_free_ptr(arr); +} + +mlx_array mlx_array_new_bool(bool val) { + return mlx_array_new_bool_ptr(val); +} + +mlx_array mlx_array_new_int(int val) { + return mlx_array_new_int_ptr(val); +} + +mlx_array mlx_array_new_float32(float val) { + return mlx_array_new_float32_ptr(val); +} + +mlx_array mlx_array_new_float(float val) { + return mlx_array_new_float_ptr(val); +} + +mlx_array mlx_array_new_float64(double val) { + return mlx_array_new_float64_ptr(val); +} + +mlx_array mlx_array_new_double(double val) { + return mlx_array_new_double_ptr(val); +} + +mlx_array mlx_array_new_complex(float real_val, float imag_val) { + return mlx_array_new_complex_ptr(real_val, imag_val); +} + +mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype) { + return mlx_array_new_data_ptr(data, shape, dim, dtype); +} + +int mlx_array_set(mlx_array* arr, const mlx_array src) { + return mlx_array_set_ptr(arr, src); +} + +int mlx_array_set_bool(mlx_array* arr, bool val) { + return mlx_array_set_bool_ptr(arr, val); +} + +int mlx_array_set_int(mlx_array* arr, int val) { + return mlx_array_set_int_ptr(arr, val); +} + +int mlx_array_set_float32(mlx_array* arr, float val) { + return mlx_array_set_float32_ptr(arr, val); +} + +int mlx_array_set_float(mlx_array* arr, float val) { + return mlx_array_set_float_ptr(arr, val); +} + +int mlx_array_set_float64(mlx_array* arr, double val) { + return mlx_array_set_float64_ptr(arr, val); +} + +int mlx_array_set_double(mlx_array* arr, double val) { + return mlx_array_set_double_ptr(arr, val); +} + +int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) { + return mlx_array_set_complex_ptr(arr, real_val, imag_val); +} + +int mlx_array_set_data(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype) { + return mlx_array_set_data_ptr(arr, data, shape, dim, dtype); +} + +size_t mlx_array_itemsize(const mlx_array arr) { + return mlx_array_itemsize_ptr(arr); +} + +size_t mlx_array_size(const mlx_array arr) { + return mlx_array_size_ptr(arr); +} + +size_t mlx_array_nbytes(const mlx_array arr) { + return mlx_array_nbytes_ptr(arr); +} + +size_t mlx_array_ndim(const mlx_array arr) { + return mlx_array_ndim_ptr(arr); +} + +const int* mlx_array_shape(const mlx_array arr) { + return mlx_array_shape_ptr(arr); +} + +const size_t* mlx_array_strides(const mlx_array arr) { + return mlx_array_strides_ptr(arr); +} + +int mlx_array_dim(const mlx_array arr, int dim) { + return mlx_array_dim_ptr(arr, dim); +} + +mlx_dtype mlx_array_dtype(const mlx_array arr) { + return mlx_array_dtype_ptr(arr); +} + +int mlx_array_eval(mlx_array arr) { + return mlx_array_eval_ptr(arr); +} + +int mlx_array_item_bool(bool* res, const mlx_array arr) { + return mlx_array_item_bool_ptr(res, arr); +} + +int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) { + return mlx_array_item_uint8_ptr(res, arr); +} + +int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) { + return mlx_array_item_uint16_ptr(res, arr); +} + +int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) { + return mlx_array_item_uint32_ptr(res, arr); +} + +int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) { + return mlx_array_item_uint64_ptr(res, arr); +} + +int mlx_array_item_int8(int8_t* res, const mlx_array arr) { + return mlx_array_item_int8_ptr(res, arr); +} + +int mlx_array_item_int16(int16_t* res, const mlx_array arr) { + return mlx_array_item_int16_ptr(res, arr); +} + +int mlx_array_item_int32(int32_t* res, const mlx_array arr) { + return mlx_array_item_int32_ptr(res, arr); +} + +int mlx_array_item_int64(int64_t* res, const mlx_array arr) { + return mlx_array_item_int64_ptr(res, arr); +} + +int mlx_array_item_float32(float* res, const mlx_array arr) { + return mlx_array_item_float32_ptr(res, arr); +} + +int mlx_array_item_float64(double* res, const mlx_array arr) { + return mlx_array_item_float64_ptr(res, arr); +} + +int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) { + return mlx_array_item_complex64_ptr(res, arr); +} + +#if defined(__aarch64__) || defined(_M_ARM64) +int mlx_array_item_float16(float16_t* res, const mlx_array arr) { + return mlx_array_item_float16_ptr(res, arr); +} +#endif + +#if defined(__aarch64__) || defined(_M_ARM64) +int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) { + return mlx_array_item_bfloat16_ptr(res, arr); +} +#endif + +const bool* mlx_array_data_bool(const mlx_array arr) { + return mlx_array_data_bool_ptr(arr); +} + +const uint8_t* mlx_array_data_uint8(const mlx_array arr) { + return mlx_array_data_uint8_ptr(arr); +} + +const uint16_t* mlx_array_data_uint16(const mlx_array arr) { + return mlx_array_data_uint16_ptr(arr); +} + +const uint32_t* mlx_array_data_uint32(const mlx_array arr) { + return mlx_array_data_uint32_ptr(arr); +} + +const uint64_t* mlx_array_data_uint64(const mlx_array arr) { + return mlx_array_data_uint64_ptr(arr); +} + +const int8_t* mlx_array_data_int8(const mlx_array arr) { + return mlx_array_data_int8_ptr(arr); +} + +const int16_t* mlx_array_data_int16(const mlx_array arr) { + return mlx_array_data_int16_ptr(arr); +} + +const int32_t* mlx_array_data_int32(const mlx_array arr) { + return mlx_array_data_int32_ptr(arr); +} + +const int64_t* mlx_array_data_int64(const mlx_array arr) { + return mlx_array_data_int64_ptr(arr); +} + +const float* mlx_array_data_float32(const mlx_array arr) { + return mlx_array_data_float32_ptr(arr); +} + +const double* mlx_array_data_float64(const mlx_array arr) { + return mlx_array_data_float64_ptr(arr); +} + +const float _Complex* mlx_array_data_complex64(const mlx_array arr) { + return mlx_array_data_complex64_ptr(arr); +} + +#if defined(__aarch64__) || defined(_M_ARM64) +const float16_t* mlx_array_data_float16(const mlx_array arr) { + return mlx_array_data_float16_ptr(arr); +} +#endif + +#if defined(__aarch64__) || defined(_M_ARM64) +const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr) { + return mlx_array_data_bfloat16_ptr(arr); +} +#endif + +int _mlx_array_is_available(bool* res, const mlx_array arr) { + return _mlx_array_is_available_ptr(res, arr); +} + +int _mlx_array_wait(const mlx_array arr) { + return _mlx_array_wait_ptr(arr); +} + +int _mlx_array_is_contiguous(bool* res, const mlx_array arr) { + return _mlx_array_is_contiguous_ptr(res, arr); +} + +int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) { + return _mlx_array_is_row_contiguous_ptr(res, arr); +} + +int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr) { + return _mlx_array_is_col_contiguous_ptr(res, arr); +} + +mlx_closure mlx_closure_new(void) { + return mlx_closure_new_ptr(); +} + +int mlx_closure_free(mlx_closure cls) { + return mlx_closure_free_ptr(cls); +} + +mlx_closure mlx_closure_new_func(int (*fun)(mlx_vector_array*, const mlx_vector_array)) { + return mlx_closure_new_func_ptr(fun); +} + +mlx_closure mlx_closure_new_func_payload(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { + return mlx_closure_new_func_payload_ptr(fun, payload, dtor); +} + +int mlx_closure_set(mlx_closure* cls, const mlx_closure src) { + return mlx_closure_set_ptr(cls, src); +} + +int mlx_closure_apply(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) { + return mlx_closure_apply_ptr(res, cls, input); +} + +mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) { + return mlx_closure_new_unary_ptr(fun); +} + +mlx_closure_kwargs mlx_closure_kwargs_new(void) { + return mlx_closure_kwargs_new_ptr(); +} + +int mlx_closure_kwargs_free(mlx_closure_kwargs cls) { + return mlx_closure_kwargs_free_ptr(cls); +} + +mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) { + return mlx_closure_kwargs_new_func_ptr(fun); +} + +mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)) { + return mlx_closure_kwargs_new_func_payload_ptr(fun, payload, dtor); +} + +int mlx_closure_kwargs_set(mlx_closure_kwargs* cls, const mlx_closure_kwargs src) { + return mlx_closure_kwargs_set_ptr(cls, src); +} + +int mlx_closure_kwargs_apply(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1) { + return mlx_closure_kwargs_apply_ptr(res, cls, input_0, input_1); +} + +mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) { + return mlx_closure_value_and_grad_new_ptr(); +} + +int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) { + return mlx_closure_value_and_grad_free_ptr(cls); +} + +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) { + return mlx_closure_value_and_grad_new_func_ptr(fun); +} + +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { + return mlx_closure_value_and_grad_new_func_payload_ptr(fun, payload, dtor); +} + +int mlx_closure_value_and_grad_set(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) { + return mlx_closure_value_and_grad_set_ptr(cls, src); +} + +int mlx_closure_value_and_grad_apply(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input) { + return mlx_closure_value_and_grad_apply_ptr(res_0, res_1, cls, input); +} + +mlx_closure_custom mlx_closure_custom_new(void) { + return mlx_closure_custom_new_ptr(); +} + +int mlx_closure_custom_free(mlx_closure_custom cls) { + return mlx_closure_custom_free_ptr(cls); +} + +mlx_closure_custom mlx_closure_custom_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)) { + return mlx_closure_custom_new_func_ptr(fun); +} + +mlx_closure_custom mlx_closure_custom_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) { + return mlx_closure_custom_new_func_payload_ptr(fun, payload, dtor); +} + +int mlx_closure_custom_set(mlx_closure_custom* cls, const mlx_closure_custom src) { + return mlx_closure_custom_set_ptr(cls, src); +} + +int mlx_closure_custom_apply(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2) { + return mlx_closure_custom_apply_ptr(res, cls, input_0, input_1, input_2); +} + +mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) { + return mlx_closure_custom_jvp_new_ptr(); +} + +int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) { + return mlx_closure_custom_jvp_free_ptr(cls); +} + +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)) { + return mlx_closure_custom_jvp_new_func_ptr(fun); +} + +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) { + return mlx_closure_custom_jvp_new_func_payload_ptr(fun, payload, dtor); +} + +int mlx_closure_custom_jvp_set(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) { + return mlx_closure_custom_jvp_set_ptr(cls, src); +} + +int mlx_closure_custom_jvp_apply(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num) { + return mlx_closure_custom_jvp_apply_ptr(res, cls, input_0, input_1, input_2, input_2_num); +} + +mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) { + return mlx_closure_custom_vmap_new_ptr(); +} + +int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) { + return mlx_closure_custom_vmap_free_ptr(cls); +} + +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)) { + return mlx_closure_custom_vmap_new_func_ptr(fun); +} + +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) { + return mlx_closure_custom_vmap_new_func_payload_ptr(fun, payload, dtor); +} + +int mlx_closure_custom_vmap_set(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) { + return mlx_closure_custom_vmap_set_ptr(cls, src); +} + +int mlx_closure_custom_vmap_apply(mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num) { + return mlx_closure_custom_vmap_apply_ptr(res_0, res_1, cls, input_0, input_1, input_1_num); +} + +int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) { + return mlx_compile_ptr(res, fun, shapeless); +} + +int mlx_detail_compile(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num) { + return mlx_detail_compile_ptr(res, fun, fun_id, shapeless, constants, constants_num); +} + +int mlx_detail_compile_clear_cache(void) { + return mlx_detail_compile_clear_cache_ptr(); +} + +int mlx_detail_compile_erase(uintptr_t fun_id) { + return mlx_detail_compile_erase_ptr(fun_id); +} + +int mlx_disable_compile(void) { + return mlx_disable_compile_ptr(); +} + +int mlx_enable_compile(void) { + return mlx_enable_compile_ptr(); +} + +int mlx_set_compile_mode(mlx_compile_mode mode) { + return mlx_set_compile_mode_ptr(mode); +} + +mlx_device mlx_device_new(void) { + return mlx_device_new_ptr(); +} + +mlx_device mlx_device_new_type(mlx_device_type type, int index) { + return mlx_device_new_type_ptr(type, index); +} + +int mlx_device_free(mlx_device dev) { + return mlx_device_free_ptr(dev); +} + +int mlx_device_set(mlx_device* dev, const mlx_device src) { + return mlx_device_set_ptr(dev, src); +} + +int mlx_device_tostring(mlx_string* str, mlx_device dev) { + return mlx_device_tostring_ptr(str, dev); +} + +bool mlx_device_equal(mlx_device lhs, mlx_device rhs) { + return mlx_device_equal_ptr(lhs, rhs); +} + +int mlx_device_get_index(int* index, mlx_device dev) { + return mlx_device_get_index_ptr(index, dev); +} + +int mlx_device_get_type(mlx_device_type* type, mlx_device dev) { + return mlx_device_get_type_ptr(type, dev); +} + +int mlx_get_default_device(mlx_device* dev) { + return mlx_get_default_device_ptr(dev); +} + +int mlx_set_default_device(mlx_device dev) { + return mlx_set_default_device_ptr(dev); +} + +int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) { + return mlx_distributed_all_gather_ptr(res, x, group, S); +} + +int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) { + return mlx_distributed_all_max_ptr(res, x, group, s); +} + +int mlx_distributed_all_min(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) { + return mlx_distributed_all_min_ptr(res, x, group, s); +} + +int mlx_distributed_all_sum(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) { + return mlx_distributed_all_sum_ptr(res, x, group, s); +} + +int mlx_distributed_recv(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s) { + return mlx_distributed_recv_ptr(res, shape, shape_num, dtype, src, group, s); +} + +int mlx_distributed_recv_like(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s) { + return mlx_distributed_recv_like_ptr(res, x, src, group, s); +} + +int mlx_distributed_send(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s) { + return mlx_distributed_send_ptr(res, x, dst, group, s); +} + +int mlx_distributed_sum_scatter(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) { + return mlx_distributed_sum_scatter_ptr(res, x, group, s); +} + +int mlx_distributed_group_rank(mlx_distributed_group group) { + return mlx_distributed_group_rank_ptr(group); +} + +int mlx_distributed_group_size(mlx_distributed_group group) { + return mlx_distributed_group_size_ptr(group); +} + +mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) { + return mlx_distributed_group_split_ptr(group, color, key); +} + +bool mlx_distributed_is_available(void) { + return mlx_distributed_is_available_ptr(); +} + +mlx_distributed_group mlx_distributed_init(bool strict) { + return mlx_distributed_init_ptr(strict); +} + +void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) { + mlx_set_error_handler_ptr(handler, data, dtor); +} + +void _mlx_error(const char* file, const int line, const char* fmt, ...) { + _mlx_error_ptr(file, line, fmt); +} + +int mlx_export_function(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) { + return mlx_export_function_ptr(file, fun, args, shapeless); +} + +int mlx_export_function_kwargs(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) { + return mlx_export_function_kwargs_ptr(file, fun, args, kwargs, shapeless); +} + +mlx_function_exporter mlx_function_exporter_new(const char* file, const mlx_closure fun, bool shapeless) { + return mlx_function_exporter_new_ptr(file, fun, shapeless); +} + +int mlx_function_exporter_free(mlx_function_exporter xfunc) { + return mlx_function_exporter_free_ptr(xfunc); +} + +int mlx_function_exporter_apply(const mlx_function_exporter xfunc, const mlx_vector_array args) { + return mlx_function_exporter_apply_ptr(xfunc, args); +} + +int mlx_function_exporter_apply_kwargs(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) { + return mlx_function_exporter_apply_kwargs_ptr(xfunc, args, kwargs); +} + +mlx_imported_function mlx_imported_function_new(const char* file) { + return mlx_imported_function_new_ptr(file); +} + +int mlx_imported_function_free(mlx_imported_function xfunc) { + return mlx_imported_function_free_ptr(xfunc); +} + +int mlx_imported_function_apply(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) { + return mlx_imported_function_apply_ptr(res, xfunc, args); +} + +int mlx_imported_function_apply_kwargs(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) { + return mlx_imported_function_apply_kwargs_ptr(res, xfunc, args, kwargs); +} + +mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) { + return mlx_fast_cuda_kernel_config_new_ptr(); +} + +void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) { + mlx_fast_cuda_kernel_config_free_ptr(cls); +} + +int mlx_fast_cuda_kernel_config_add_output_arg(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) { + return mlx_fast_cuda_kernel_config_add_output_arg_ptr(cls, shape, size, dtype); +} + +int mlx_fast_cuda_kernel_config_set_grid(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) { + return mlx_fast_cuda_kernel_config_set_grid_ptr(cls, grid1, grid2, grid3); +} + +int mlx_fast_cuda_kernel_config_set_thread_group(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) { + return mlx_fast_cuda_kernel_config_set_thread_group_ptr(cls, thread1, thread2, thread3); +} + +int mlx_fast_cuda_kernel_config_set_init_value(mlx_fast_cuda_kernel_config cls, float value) { + return mlx_fast_cuda_kernel_config_set_init_value_ptr(cls, value); +} + +int mlx_fast_cuda_kernel_config_set_verbose(mlx_fast_cuda_kernel_config cls, bool verbose) { + return mlx_fast_cuda_kernel_config_set_verbose_ptr(cls, verbose); +} + +int mlx_fast_cuda_kernel_config_add_template_arg_dtype(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) { + return mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr(cls, name, dtype); +} + +int mlx_fast_cuda_kernel_config_add_template_arg_int(mlx_fast_cuda_kernel_config cls, const char* name, int value) { + return mlx_fast_cuda_kernel_config_add_template_arg_int_ptr(cls, name, value); +} + +int mlx_fast_cuda_kernel_config_add_template_arg_bool(mlx_fast_cuda_kernel_config cls, const char* name, bool value) { + return mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr(cls, name, value); +} + +mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) { + return mlx_fast_cuda_kernel_new_ptr(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory); +} + +void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) { + mlx_fast_cuda_kernel_free_ptr(cls); +} + +int mlx_fast_cuda_kernel_apply(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream) { + return mlx_fast_cuda_kernel_apply_ptr(outputs, cls, inputs, config, stream); +} + +int mlx_fast_layer_norm(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s) { + return mlx_fast_layer_norm_ptr(res, x, weight, bias, eps, s); +} + +mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) { + return mlx_fast_metal_kernel_config_new_ptr(); +} + +void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) { + mlx_fast_metal_kernel_config_free_ptr(cls); +} + +int mlx_fast_metal_kernel_config_add_output_arg(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) { + return mlx_fast_metal_kernel_config_add_output_arg_ptr(cls, shape, size, dtype); +} + +int mlx_fast_metal_kernel_config_set_grid(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) { + return mlx_fast_metal_kernel_config_set_grid_ptr(cls, grid1, grid2, grid3); +} + +int mlx_fast_metal_kernel_config_set_thread_group(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) { + return mlx_fast_metal_kernel_config_set_thread_group_ptr(cls, thread1, thread2, thread3); +} + +int mlx_fast_metal_kernel_config_set_init_value(mlx_fast_metal_kernel_config cls, float value) { + return mlx_fast_metal_kernel_config_set_init_value_ptr(cls, value); +} + +int mlx_fast_metal_kernel_config_set_verbose(mlx_fast_metal_kernel_config cls, bool verbose) { + return mlx_fast_metal_kernel_config_set_verbose_ptr(cls, verbose); +} + +int mlx_fast_metal_kernel_config_add_template_arg_dtype(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) { + return mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr(cls, name, dtype); +} + +int mlx_fast_metal_kernel_config_add_template_arg_int(mlx_fast_metal_kernel_config cls, const char* name, int value) { + return mlx_fast_metal_kernel_config_add_template_arg_int_ptr(cls, name, value); +} + +int mlx_fast_metal_kernel_config_add_template_arg_bool(mlx_fast_metal_kernel_config cls, const char* name, bool value) { + return mlx_fast_metal_kernel_config_add_template_arg_bool_ptr(cls, name, value); +} + +mlx_fast_metal_kernel mlx_fast_metal_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) { + return mlx_fast_metal_kernel_new_ptr(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs); +} + +void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) { + mlx_fast_metal_kernel_free_ptr(cls); +} + +int mlx_fast_metal_kernel_apply(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) { + return mlx_fast_metal_kernel_apply_ptr(outputs, cls, inputs, config, stream); +} + +int mlx_fast_rms_norm(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s) { + return mlx_fast_rms_norm_ptr(res, x, weight, eps, s); +} + +int mlx_fast_rope(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s) { + return mlx_fast_rope_ptr(res, x, dims, traditional, base, scale, offset, freqs, s); +} + +int mlx_fast_rope_dynamic(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s) { + return mlx_fast_rope_dynamic_ptr(res, x, dims, traditional, base, scale, offset, freqs, s); +} + +int mlx_fast_scaled_dot_product_attention(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s) { + return mlx_fast_scaled_dot_product_attention_ptr(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s); +} + +int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { + return mlx_fft_fft_ptr(res, a, n, axis, s); +} + +int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_fft2_ptr(res, a, n, n_num, axes, axes_num, s); +} + +int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_fftn_ptr(res, a, n, n_num, axes, axes_num, s); +} + +int mlx_fft_fftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_fftshift_ptr(res, a, axes, axes_num, s); +} + +int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { + return mlx_fft_ifft_ptr(res, a, n, axis, s); +} + +int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_ifft2_ptr(res, a, n, n_num, axes, axes_num, s); +} + +int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_ifftn_ptr(res, a, n, n_num, axes, axes_num, s); +} + +int mlx_fft_ifftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_ifftshift_ptr(res, a, axes, axes_num, s); +} + +int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { + return mlx_fft_irfft_ptr(res, a, n, axis, s); +} + +int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_irfft2_ptr(res, a, n, n_num, axes, axes_num, s); +} + +int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_irfftn_ptr(res, a, n, n_num, axes, axes_num, s); +} + +int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) { + return mlx_fft_rfft_ptr(res, a, n, axis, s); +} + +int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_rfft2_ptr(res, a, n, n_num, axes, axes_num, s); +} + +int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_fft_rfftn_ptr(res, a, n, n_num, axes, axes_num, s); +} + +int mlx_load_reader(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) { + return mlx_load_reader_ptr(res, in_stream, s); +} + +int mlx_load(mlx_array* res, const char* file, const mlx_stream s) { + return mlx_load_ptr(res, file, s); +} + +int mlx_load_safetensors_reader(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) { + return mlx_load_safetensors_reader_ptr(res_0, res_1, in_stream, s); +} + +int mlx_load_safetensors(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) { + return mlx_load_safetensors_ptr(res_0, res_1, file, s); +} + +int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) { + return mlx_save_writer_ptr(out_stream, a); +} + +int mlx_save(const char* file, const mlx_array a) { + return mlx_save_ptr(file, a); +} + +int mlx_save_safetensors_writer(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { + return mlx_save_safetensors_writer_ptr(in_stream, param, metadata); +} + +int mlx_save_safetensors(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) { + return mlx_save_safetensors_ptr(file, param, metadata); +} + +mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) { + return mlx_io_reader_new_ptr(desc, vtable); +} + +int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) { + return mlx_io_reader_descriptor_ptr(desc_, io); +} + +int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) { + return mlx_io_reader_tostring_ptr(str_, io); +} + +int mlx_io_reader_free(mlx_io_reader io) { + return mlx_io_reader_free_ptr(io); +} + +mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) { + return mlx_io_writer_new_ptr(desc, vtable); +} + +int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) { + return mlx_io_writer_descriptor_ptr(desc_, io); +} + +int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) { + return mlx_io_writer_tostring_ptr(str_, io); +} + +int mlx_io_writer_free(mlx_io_writer io) { + return mlx_io_writer_free_ptr(io); +} + +int mlx_linalg_cholesky(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { + return mlx_linalg_cholesky_ptr(res, a, upper, s); +} + +int mlx_linalg_cholesky_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { + return mlx_linalg_cholesky_inv_ptr(res, a, upper, s); +} + +int mlx_linalg_cross(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) { + return mlx_linalg_cross_ptr(res, a, b, axis, s); +} + +int mlx_linalg_eig(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) { + return mlx_linalg_eig_ptr(res_0, res_1, a, s); +} + +int mlx_linalg_eigh(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) { + return mlx_linalg_eigh_ptr(res_0, res_1, a, UPLO, s); +} + +int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_linalg_eigvals_ptr(res, a, s); +} + +int mlx_linalg_eigvalsh(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) { + return mlx_linalg_eigvalsh_ptr(res, a, UPLO, s); +} + +int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_linalg_inv_ptr(res, a, s); +} + +int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) { + return mlx_linalg_lu_ptr(res, a, s); +} + +int mlx_linalg_lu_factor(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) { + return mlx_linalg_lu_factor_ptr(res_0, res_1, a, s); +} + +int mlx_linalg_norm(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) { + return mlx_linalg_norm_ptr(res, a, ord, axis, axis_num, keepdims, s); +} + +int mlx_linalg_norm_matrix(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) { + return mlx_linalg_norm_matrix_ptr(res, a, ord, axis, axis_num, keepdims, s); +} + +int mlx_linalg_norm_l2(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) { + return mlx_linalg_norm_l2_ptr(res, a, axis, axis_num, keepdims, s); +} + +int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_linalg_pinv_ptr(res, a, s); +} + +int mlx_linalg_qr(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) { + return mlx_linalg_qr_ptr(res_0, res_1, a, s); +} + +int mlx_linalg_solve(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_linalg_solve_ptr(res, a, b, s); +} + +int mlx_linalg_solve_triangular(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) { + return mlx_linalg_solve_triangular_ptr(res, a, b, upper, s); +} + +int mlx_linalg_svd(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) { + return mlx_linalg_svd_ptr(res, a, compute_uv, s); +} + +int mlx_linalg_tri_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) { + return mlx_linalg_tri_inv_ptr(res, a, upper, s); +} + +mlx_map_string_to_array mlx_map_string_to_array_new(void) { + return mlx_map_string_to_array_new_ptr(); +} + +int mlx_map_string_to_array_set(mlx_map_string_to_array* map, const mlx_map_string_to_array src) { + return mlx_map_string_to_array_set_ptr(map, src); +} + +int mlx_map_string_to_array_free(mlx_map_string_to_array map) { + return mlx_map_string_to_array_free_ptr(map); +} + +int mlx_map_string_to_array_insert(mlx_map_string_to_array map, const char* key, const mlx_array value) { + return mlx_map_string_to_array_insert_ptr(map, key, value); +} + +int mlx_map_string_to_array_get(mlx_array* value, const mlx_map_string_to_array map, const char* key) { + return mlx_map_string_to_array_get_ptr(value, map, key); +} + +mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(mlx_map_string_to_array map) { + return mlx_map_string_to_array_iterator_new_ptr(map); +} + +int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) { + return mlx_map_string_to_array_iterator_free_ptr(it); +} + +int mlx_map_string_to_array_iterator_next(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) { + return mlx_map_string_to_array_iterator_next_ptr(key, value, it); +} + +mlx_map_string_to_string mlx_map_string_to_string_new(void) { + return mlx_map_string_to_string_new_ptr(); +} + +int mlx_map_string_to_string_set(mlx_map_string_to_string* map, const mlx_map_string_to_string src) { + return mlx_map_string_to_string_set_ptr(map, src); +} + +int mlx_map_string_to_string_free(mlx_map_string_to_string map) { + return mlx_map_string_to_string_free_ptr(map); +} + +int mlx_map_string_to_string_insert(mlx_map_string_to_string map, const char* key, const char* value) { + return mlx_map_string_to_string_insert_ptr(map, key, value); +} + +int mlx_map_string_to_string_get(const char** value, const mlx_map_string_to_string map, const char* key) { + return mlx_map_string_to_string_get_ptr(value, map, key); +} + +mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(mlx_map_string_to_string map) { + return mlx_map_string_to_string_iterator_new_ptr(map); +} + +int mlx_map_string_to_string_iterator_free(mlx_map_string_to_string_iterator it) { + return mlx_map_string_to_string_iterator_free_ptr(it); +} + +int mlx_map_string_to_string_iterator_next(const char** key, const char** value, mlx_map_string_to_string_iterator it) { + return mlx_map_string_to_string_iterator_next_ptr(key, value, it); +} + +int mlx_clear_cache(void) { + return mlx_clear_cache_ptr(); +} + +int mlx_get_active_memory(size_t* res) { + return mlx_get_active_memory_ptr(res); +} + +int mlx_get_cache_memory(size_t* res) { + return mlx_get_cache_memory_ptr(res); +} + +int mlx_get_memory_limit(size_t* res) { + return mlx_get_memory_limit_ptr(res); +} + +int mlx_get_peak_memory(size_t* res) { + return mlx_get_peak_memory_ptr(res); +} + +int mlx_reset_peak_memory(void) { + return mlx_reset_peak_memory_ptr(); +} + +int mlx_set_cache_limit(size_t* res, size_t limit) { + return mlx_set_cache_limit_ptr(res, limit); +} + +int mlx_set_memory_limit(size_t* res, size_t limit) { + return mlx_set_memory_limit_ptr(res, limit); +} + +int mlx_set_wired_limit(size_t* res, size_t limit) { + return mlx_set_wired_limit_ptr(res, limit); +} + +mlx_metal_device_info_t mlx_metal_device_info(void) { + return mlx_metal_device_info_ptr(); +} + +int mlx_metal_is_available(bool* res) { + return mlx_metal_is_available_ptr(res); +} + +int mlx_metal_start_capture(const char* path) { + return mlx_metal_start_capture_ptr(path); +} + +int mlx_metal_stop_capture(void) { + return mlx_metal_stop_capture_ptr(); +} + +int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_abs_ptr(res, a, s); +} + +int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_add_ptr(res, a, b, s); +} + +int mlx_addmm(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) { + return mlx_addmm_ptr(res, c, a, b, alpha, beta, s); +} + +int mlx_all_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_all_axes_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_all_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_all_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_all(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_all_ptr(res, a, keepdims, s); +} + +int mlx_allclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) { + return mlx_allclose_ptr(res, a, b, rtol, atol, equal_nan, s); +} + +int mlx_any_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_any_axes_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_any_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_any_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_any(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_any_ptr(res, a, keepdims, s); +} + +int mlx_arange(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) { + return mlx_arange_ptr(res, start, stop, step, dtype, s); +} + +int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arccos_ptr(res, a, s); +} + +int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arccosh_ptr(res, a, s); +} + +int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arcsin_ptr(res, a, s); +} + +int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arcsinh_ptr(res, a, s); +} + +int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arctan_ptr(res, a, s); +} + +int mlx_arctan2(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_arctan2_ptr(res, a, b, s); +} + +int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_arctanh_ptr(res, a, s); +} + +int mlx_argmax_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_argmax_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_argmax(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_argmax_ptr(res, a, keepdims, s); +} + +int mlx_argmin_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_argmin_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_argmin(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_argmin_ptr(res, a, keepdims, s); +} + +int mlx_argpartition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) { + return mlx_argpartition_axis_ptr(res, a, kth, axis, s); +} + +int mlx_argpartition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) { + return mlx_argpartition_ptr(res, a, kth, s); +} + +int mlx_argsort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { + return mlx_argsort_axis_ptr(res, a, axis, s); +} + +int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_argsort_ptr(res, a, s); +} + +int mlx_array_equal(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) { + return mlx_array_equal_ptr(res, a, b, equal_nan, s); +} + +int mlx_as_strided(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) { + return mlx_as_strided_ptr(res, a, shape, shape_num, strides, strides_num, offset, s); +} + +int mlx_astype(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) { + return mlx_astype_ptr(res, a, dtype, s); +} + +int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_atleast_1d_ptr(res, a, s); +} + +int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_atleast_2d_ptr(res, a, s); +} + +int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_atleast_3d_ptr(res, a, s); +} + +int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_bitwise_and_ptr(res, a, b, s); +} + +int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_bitwise_invert_ptr(res, a, s); +} + +int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_bitwise_or_ptr(res, a, b, s); +} + +int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_bitwise_xor_ptr(res, a, b, s); +} + +int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) { + return mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); +} + +int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) { + return mlx_broadcast_arrays_ptr(res, inputs, s); +} + +int mlx_broadcast_to(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) { + return mlx_broadcast_to_ptr(res, a, shape, shape_num, s); +} + +int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_ceil_ptr(res, a, s); +} + +int mlx_clip(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s) { + return mlx_clip_ptr(res, a, a_min, a_max, s); +} + +int mlx_concatenate_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) { + return mlx_concatenate_axis_ptr(res, arrays, axis, s); +} + +int mlx_concatenate(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { + return mlx_concatenate_ptr(res, arrays, s); +} + +int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_conjugate_ptr(res, a, s); +} + +int mlx_contiguous(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) { + return mlx_contiguous_ptr(res, a, allow_col_major, s); +} + +int mlx_conv1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) { + return mlx_conv1d_ptr(res, input, weight, stride, padding, dilation, groups, s); +} + +int mlx_conv2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) { + return mlx_conv2d_ptr(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s); +} + +int mlx_conv3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) { + return mlx_conv3d_ptr(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s); +} + +int mlx_conv_general(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) { + return mlx_conv_general_ptr(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s); +} + +int mlx_conv_transpose1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) { + return mlx_conv_transpose1d_ptr(res, input, weight, stride, padding, dilation, output_padding, groups, s); +} + +int mlx_conv_transpose2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) { + return mlx_conv_transpose2d_ptr(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s); +} + +int mlx_conv_transpose3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) { + return mlx_conv_transpose3d_ptr(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s); +} + +int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_copy_ptr(res, a, s); +} + +int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_cos_ptr(res, a, s); +} + +int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_cosh_ptr(res, a, s); +} + +int mlx_cummax(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { + return mlx_cummax_ptr(res, a, axis, reverse, inclusive, s); +} + +int mlx_cummin(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { + return mlx_cummin_ptr(res, a, axis, reverse, inclusive, s); +} + +int mlx_cumprod(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { + return mlx_cumprod_ptr(res, a, axis, reverse, inclusive, s); +} + +int mlx_cumsum(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { + return mlx_cumsum_ptr(res, a, axis, reverse, inclusive, s); +} + +int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_degrees_ptr(res, a, s); +} + +int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) { + return mlx_depends_ptr(res, inputs, dependencies); +} + +int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) { + return mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s); +} + +int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { + return mlx_diag_ptr(res, a, k, s); +} + +int mlx_diagonal(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) { + return mlx_diagonal_ptr(res, a, offset, axis1, axis2, s); +} + +int mlx_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_divide_ptr(res, a, b, s); +} + +int mlx_divmod(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_divmod_ptr(res, a, b, s); +} + +int mlx_einsum(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) { + return mlx_einsum_ptr(res, subscripts, operands, s); +} + +int mlx_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_equal_ptr(res, a, b, s); +} + +int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_erf_ptr(res, a, s); +} + +int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_erfinv_ptr(res, a, s); +} + +int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_exp_ptr(res, a, s); +} + +int mlx_expand_dims_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_expand_dims_axes_ptr(res, a, axes, axes_num, s); +} + +int mlx_expand_dims(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { + return mlx_expand_dims_ptr(res, a, axis, s); +} + +int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_expm1_ptr(res, a, s); +} + +int mlx_eye(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) { + return mlx_eye_ptr(res, n, m, k, dtype, s); +} + +int mlx_flatten(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s) { + return mlx_flatten_ptr(res, a, start_axis, end_axis, s); +} + +int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_floor_ptr(res, a, s); +} + +int mlx_floor_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_floor_divide_ptr(res, a, b, s); +} + +int mlx_from_fp8(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) { + return mlx_from_fp8_ptr(res, x, dtype, s); +} + +int mlx_full(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) { + return mlx_full_ptr(res, shape, shape_num, vals, dtype, s); +} + +int mlx_full_like(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) { + return mlx_full_like_ptr(res, a, vals, dtype, s); +} + +int mlx_gather(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) { + return mlx_gather_ptr(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s); +} + +int mlx_gather_single(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) { + return mlx_gather_single_ptr(res, a, indices, axis, slice_sizes, slice_sizes_num, s); +} + +int mlx_gather_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s) { + return mlx_gather_mm_ptr(res, a, b, lhs_indices, rhs_indices, sorted_indices, s); +} + +int mlx_gather_qmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) { + return mlx_gather_qmm_ptr(res, x, w, scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s); +} + +int mlx_greater(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_greater_ptr(res, a, b, s); +} + +int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_greater_equal_ptr(res, a, b, s); +} + +int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) { + return mlx_hadamard_transform_ptr(res, a, scale, s); +} + +int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) { + return mlx_identity_ptr(res, n, dtype, s); +} + +int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_imag_ptr(res, a, s); +} + +int mlx_inner(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_inner_ptr(res, a, b, s); +} + +int mlx_isclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) { + return mlx_isclose_ptr(res, a, b, rtol, atol, equal_nan, s); +} + +int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isfinite_ptr(res, a, s); +} + +int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isinf_ptr(res, a, s); +} + +int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isnan_ptr(res, a, s); +} + +int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isneginf_ptr(res, a, s); +} + +int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_isposinf_ptr(res, a, s); +} + +int mlx_kron(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_kron_ptr(res, a, b, s); +} + +int mlx_left_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_left_shift_ptr(res, a, b, s); +} + +int mlx_less(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_less_ptr(res, a, b, s); +} + +int mlx_less_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_less_equal_ptr(res, a, b, s); +} + +int mlx_linspace(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) { + return mlx_linspace_ptr(res, start, stop, num, dtype, s); +} + +int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_log_ptr(res, a, s); +} + +int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_log10_ptr(res, a, s); +} + +int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_log1p_ptr(res, a, s); +} + +int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_log2_ptr(res, a, s); +} + +int mlx_logaddexp(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_logaddexp_ptr(res, a, b, s); +} + +int mlx_logcumsumexp(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) { + return mlx_logcumsumexp_ptr(res, a, axis, reverse, inclusive, s); +} + +int mlx_logical_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_logical_and_ptr(res, a, b, s); +} + +int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_logical_not_ptr(res, a, s); +} + +int mlx_logical_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_logical_or_ptr(res, a, b, s); +} + +int mlx_logsumexp_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_logsumexp_axes_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_logsumexp_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_logsumexp_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_logsumexp(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_logsumexp_ptr(res, a, keepdims, s); +} + +int mlx_masked_scatter(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) { + return mlx_masked_scatter_ptr(res, a, mask, src, s); +} + +int mlx_matmul(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_matmul_ptr(res, a, b, s); +} + +int mlx_max_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_max_axes_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_max_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_max_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_max(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_max_ptr(res, a, keepdims, s); +} + +int mlx_maximum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_maximum_ptr(res, a, b, s); +} + +int mlx_mean_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_mean_axes_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_mean_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_mean_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_mean(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_mean_ptr(res, a, keepdims, s); +} + +int mlx_median(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_median_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_meshgrid(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s) { + return mlx_meshgrid_ptr(res, arrays, sparse, indexing, s); +} + +int mlx_min_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_min_axes_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_min_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_min_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_min(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_min_ptr(res, a, keepdims, s); +} + +int mlx_minimum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_minimum_ptr(res, a, b, s); +} + +int mlx_moveaxis(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) { + return mlx_moveaxis_ptr(res, a, source, destination, s); +} + +int mlx_multiply(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_multiply_ptr(res, a, b, s); +} + +int mlx_nan_to_num(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) { + return mlx_nan_to_num_ptr(res, a, nan, posinf, neginf, s); +} + +int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_negative_ptr(res, a, s); +} + +int mlx_not_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_not_equal_ptr(res, a, b, s); +} + +int mlx_number_of_elements(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) { + return mlx_number_of_elements_ptr(res, a, axes, axes_num, inverted, dtype, s); +} + +int mlx_ones(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) { + return mlx_ones_ptr(res, shape, shape_num, dtype, s); +} + +int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_ones_like_ptr(res, a, s); +} + +int mlx_outer(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_outer_ptr(res, a, b, s); +} + +int mlx_pad(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s) { + return mlx_pad_ptr(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, high_pad_size_num, pad_value, mode, s); +} + +int mlx_pad_symmetric(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s) { + return mlx_pad_symmetric_ptr(res, a, pad_width, pad_value, mode, s); +} + +int mlx_partition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) { + return mlx_partition_axis_ptr(res, a, kth, axis, s); +} + +int mlx_partition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) { + return mlx_partition_ptr(res, a, kth, s); +} + +int mlx_power(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_power_ptr(res, a, b, s); +} + +int mlx_prod_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_prod_axes_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_prod_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_prod_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_prod_ptr(res, a, keepdims, s); +} + +int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) { + return mlx_put_along_axis_ptr(res, a, indices, values, axis, s); +} + +int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { + return mlx_qqmm_ptr(res, x, w, w_scales, group_size, bits, mode, s); +} + +int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { + return mlx_quantize_ptr(res, w, group_size, bits, mode, s); +} + +int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) { + return mlx_quantized_matmul_ptr(res, x, w, scales, biases, transpose, group_size, bits, mode, s); +} + +int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_radians_ptr(res, a, s); +} + +int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_real_ptr(res, a, s); +} + +int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_reciprocal_ptr(res, a, s); +} + +int mlx_remainder(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_remainder_ptr(res, a, b, s); +} + +int mlx_repeat_axis(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s) { + return mlx_repeat_axis_ptr(res, arr, repeats, axis, s); +} + +int mlx_repeat(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s) { + return mlx_repeat_ptr(res, arr, repeats, s); +} + +int mlx_reshape(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) { + return mlx_reshape_ptr(res, a, shape, shape_num, s); +} + +int mlx_right_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_right_shift_ptr(res, a, b, s); +} + +int mlx_roll_axis(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s) { + return mlx_roll_axis_ptr(res, a, shift, shift_num, axis, s); +} + +int mlx_roll_axes(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_roll_axes_ptr(res, a, shift, shift_num, axes, axes_num, s); +} + +int mlx_roll(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s) { + return mlx_roll_ptr(res, a, shift, shift_num, s); +} + +int mlx_round(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s) { + return mlx_round_ptr(res, a, decimals, s); +} + +int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_rsqrt_ptr(res, a, s); +} + +int mlx_scatter(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_scatter_ptr(res, a, indices, updates, axes, axes_num, s); +} + +int mlx_scatter_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { + return mlx_scatter_single_ptr(res, a, indices, updates, axis, s); +} + +int mlx_scatter_add(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_scatter_add_ptr(res, a, indices, updates, axes, axes_num, s); +} + +int mlx_scatter_add_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { + return mlx_scatter_add_single_ptr(res, a, indices, updates, axis, s); +} + +int mlx_scatter_add_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) { + return mlx_scatter_add_axis_ptr(res, a, indices, values, axis, s); +} + +int mlx_scatter_max(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_scatter_max_ptr(res, a, indices, updates, axes, axes_num, s); +} + +int mlx_scatter_max_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { + return mlx_scatter_max_single_ptr(res, a, indices, updates, axis, s); +} + +int mlx_scatter_min(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_scatter_min_ptr(res, a, indices, updates, axes, axes_num, s); +} + +int mlx_scatter_min_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { + return mlx_scatter_min_single_ptr(res, a, indices, updates, axis, s); +} + +int mlx_scatter_prod(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_scatter_prod_ptr(res, a, indices, updates, axes, axes_num, s); +} + +int mlx_scatter_prod_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) { + return mlx_scatter_prod_single_ptr(res, a, indices, updates, axis, s); +} + +int mlx_segmented_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s) { + return mlx_segmented_mm_ptr(res, a, b, segments, s); +} + +int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sigmoid_ptr(res, a, s); +} + +int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sign_ptr(res, a, s); +} + +int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sin_ptr(res, a, s); +} + +int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sinh_ptr(res, a, s); +} + +int mlx_slice(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { + return mlx_slice_ptr(res, a, start, start_num, stop, stop_num, strides, strides_num, s); +} + +int mlx_slice_dynamic(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) { + return mlx_slice_dynamic_ptr(res, a, start, axes, axes_num, slice_size, slice_size_num, s); +} + +int mlx_slice_update(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) { + return mlx_slice_update_ptr(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s); +} + +int mlx_slice_update_dynamic(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_slice_update_dynamic_ptr(res, src, update, start, axes, axes_num, s); +} + +int mlx_softmax_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) { + return mlx_softmax_axes_ptr(res, a, axes, axes_num, precise, s); +} + +int mlx_softmax_axis(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) { + return mlx_softmax_axis_ptr(res, a, axis, precise, s); +} + +int mlx_softmax(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s) { + return mlx_softmax_ptr(res, a, precise, s); +} + +int mlx_sort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { + return mlx_sort_axis_ptr(res, a, axis, s); +} + +int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sort_ptr(res, a, s); +} + +int mlx_split(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s) { + return mlx_split_ptr(res, a, num_splits, axis, s); +} + +int mlx_split_sections(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s) { + return mlx_split_sections_ptr(res, a, indices, indices_num, axis, s); +} + +int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_sqrt_ptr(res, a, s); +} + +int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_square_ptr(res, a, s); +} + +int mlx_squeeze_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_squeeze_axes_ptr(res, a, axes, axes_num, s); +} + +int mlx_squeeze_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) { + return mlx_squeeze_axis_ptr(res, a, axis, s); +} + +int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_squeeze_ptr(res, a, s); +} + +int mlx_stack_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) { + return mlx_stack_axis_ptr(res, arrays, axis, s); +} + +int mlx_stack(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) { + return mlx_stack_ptr(res, arrays, s); +} + +int mlx_std_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) { + return mlx_std_axes_ptr(res, a, axes, axes_num, keepdims, ddof, s); +} + +int mlx_std_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) { + return mlx_std_axis_ptr(res, a, axis, keepdims, ddof, s); +} + +int mlx_std(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) { + return mlx_std_ptr(res, a, keepdims, ddof, s); +} + +int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_stop_gradient_ptr(res, a, s); +} + +int mlx_subtract(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) { + return mlx_subtract_ptr(res, a, b, s); +} + +int mlx_sum_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) { + return mlx_sum_axes_ptr(res, a, axes, axes_num, keepdims, s); +} + +int mlx_sum_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) { + return mlx_sum_axis_ptr(res, a, axis, keepdims, s); +} + +int mlx_sum(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) { + return mlx_sum_ptr(res, a, keepdims, s); +} + +int mlx_swapaxes(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s) { + return mlx_swapaxes_ptr(res, a, axis1, axis2, s); +} + +int mlx_take_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) { + return mlx_take_axis_ptr(res, a, indices, axis, s); +} + +int mlx_take(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s) { + return mlx_take_ptr(res, a, indices, s); +} + +int mlx_take_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) { + return mlx_take_along_axis_ptr(res, a, indices, axis, s); +} + +int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_tan_ptr(res, a, s); +} + +int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_tanh_ptr(res, a, s); +} + +int mlx_tensordot(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s) { + return mlx_tensordot_ptr(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s); +} + +int mlx_tensordot_axis(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) { + return mlx_tensordot_axis_ptr(res, a, b, axis, s); +} + +int mlx_tile(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s) { + return mlx_tile_ptr(res, arr, reps, reps_num, s); +} + +int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s) { + return mlx_to_fp8_ptr(res, x, s); +} + +int mlx_topk_axis(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s) { + return mlx_topk_axis_ptr(res, a, k, axis, s); +} + +int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s) { + return mlx_topk_ptr(res, a, k, s); +} + +int mlx_trace(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s) { + return mlx_trace_ptr(res, a, offset, axis1, axis2, dtype, s); +} + +int mlx_transpose_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) { + return mlx_transpose_axes_ptr(res, a, axes, axes_num, s); +} + +int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_transpose_ptr(res, a, s); +} + +int mlx_tri(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s) { + return mlx_tri_ptr(res, n, m, k, type, s); +} + +int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { + return mlx_tril_ptr(res, x, k, s); +} + +int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s) { + return mlx_triu_ptr(res, x, k, s); +} + +int mlx_unflatten(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s) { + return mlx_unflatten_ptr(res, a, axis, shape, shape_num, s); +} + +int mlx_var_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) { + return mlx_var_axes_ptr(res, a, axes, axes_num, keepdims, ddof, s); +} + +int mlx_var_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) { + return mlx_var_axis_ptr(res, a, axis, keepdims, ddof, s); +} + +int mlx_var(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) { + return mlx_var_ptr(res, a, keepdims, ddof, s); +} + +int mlx_view(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) { + return mlx_view_ptr(res, a, dtype, s); +} + +int mlx_where(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s) { + return mlx_where_ptr(res, condition, x, y, s); +} + +int mlx_zeros(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) { + return mlx_zeros_ptr(res, shape, shape_num, dtype, s); +} + +int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s) { + return mlx_zeros_like_ptr(res, a, s); +} + +int mlx_random_bernoulli(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) { + return mlx_random_bernoulli_ptr(res, p, shape, shape_num, key, s); +} + +int mlx_random_bits(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s) { + return mlx_random_bits_ptr(res, shape, shape_num, width, key, s); +} + +int mlx_random_categorical_shape(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) { + return mlx_random_categorical_shape_ptr(res, logits, axis, shape, shape_num, key, s); +} + +int mlx_random_categorical_num_samples(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s) { + return mlx_random_categorical_num_samples_ptr(res, logits_, axis, num_samples, key, s); +} + +int mlx_random_categorical(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s) { + return mlx_random_categorical_ptr(res, logits, axis, key, s); +} + +int mlx_random_gumbel(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { + return mlx_random_gumbel_ptr(res, shape, shape_num, dtype, key, s); +} + +int mlx_random_key(mlx_array* res, uint64_t seed) { + return mlx_random_key_ptr(res, seed); +} + +int mlx_random_laplace(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) { + return mlx_random_laplace_ptr(res, shape, shape_num, dtype, loc, scale, key, s); +} + +int mlx_random_multivariate_normal(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { + return mlx_random_multivariate_normal_ptr(res, mean, cov, shape, shape_num, dtype, key, s); +} + +int mlx_random_normal_broadcast(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s) { + return mlx_random_normal_broadcast_ptr(res, shape, shape_num, dtype, loc, scale, key, s); +} + +int mlx_random_normal(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) { + return mlx_random_normal_ptr(res, shape, shape_num, dtype, loc, scale, key, s); +} + +int mlx_random_permutation(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s) { + return mlx_random_permutation_ptr(res, x, axis, key, s); +} + +int mlx_random_permutation_arange(mlx_array* res, int x, const mlx_array key , const mlx_stream s) { + return mlx_random_permutation_arange_ptr(res, x, key, s); +} + +int mlx_random_randint(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { + return mlx_random_randint_ptr(res, low, high, shape, shape_num, dtype, key, s); +} + +int mlx_random_seed(uint64_t seed) { + return mlx_random_seed_ptr(seed); +} + +int mlx_random_split_num(mlx_array* res, const mlx_array key, int num, const mlx_stream s) { + return mlx_random_split_num_ptr(res, key, num, s); +} + +int mlx_random_split(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s) { + return mlx_random_split_ptr(res_0, res_1, key, s); +} + +int mlx_random_truncated_normal(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { + return mlx_random_truncated_normal_ptr(res, lower, upper, shape, shape_num, dtype, key, s); +} + +int mlx_random_uniform(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) { + return mlx_random_uniform_ptr(res, low, high, shape, shape_num, dtype, key, s); +} + +mlx_stream mlx_stream_new(void) { + return mlx_stream_new_ptr(); +} + +mlx_stream mlx_stream_new_device(mlx_device dev) { + return mlx_stream_new_device_ptr(dev); +} + +int mlx_stream_set(mlx_stream* stream, const mlx_stream src) { + return mlx_stream_set_ptr(stream, src); +} + +int mlx_stream_free(mlx_stream stream) { + return mlx_stream_free_ptr(stream); +} + +int mlx_stream_tostring(mlx_string* str, mlx_stream stream) { + return mlx_stream_tostring_ptr(str, stream); +} + +bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) { + return mlx_stream_equal_ptr(lhs, rhs); +} + +int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) { + return mlx_stream_get_device_ptr(dev, stream); +} + +int mlx_stream_get_index(int* index, mlx_stream stream) { + return mlx_stream_get_index_ptr(index, stream); +} + +int mlx_synchronize(mlx_stream stream) { + return mlx_synchronize_ptr(stream); +} + +int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) { + return mlx_get_default_stream_ptr(stream, dev); +} + +int mlx_set_default_stream(mlx_stream stream) { + return mlx_set_default_stream_ptr(stream); +} + +mlx_stream mlx_default_cpu_stream_new(void) { + return mlx_default_cpu_stream_new_ptr(); +} + +mlx_stream mlx_default_gpu_stream_new(void) { + return mlx_default_gpu_stream_new_ptr(); +} + +mlx_string mlx_string_new(void) { + return mlx_string_new_ptr(); +} + +mlx_string mlx_string_new_data(const char* str) { + return mlx_string_new_data_ptr(str); +} + +int mlx_string_set(mlx_string* str, const mlx_string src) { + return mlx_string_set_ptr(str, src); +} + +const char* mlx_string_data(mlx_string str) { + return mlx_string_data_ptr(str); +} + +int mlx_string_free(mlx_string str) { + return mlx_string_free_ptr(str); +} + +int mlx_async_eval(const mlx_vector_array outputs) { + return mlx_async_eval_ptr(outputs); +} + +int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) { + return mlx_checkpoint_ptr(res, fun); +} + +int mlx_custom_function(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap) { + return mlx_custom_function_ptr(res, fun, fun_vjp, fun_jvp, fun_vmap); +} + +int mlx_custom_vjp(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp) { + return mlx_custom_vjp_ptr(res, fun, fun_vjp); +} + +int mlx_eval(const mlx_vector_array outputs) { + return mlx_eval_ptr(outputs); +} + +int mlx_jvp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents) { + return mlx_jvp_ptr(res_0, res_1, fun, primals, tangents); +} + +int mlx_value_and_grad(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num) { + return mlx_value_and_grad_ptr(res, fun, argnums, argnums_num); +} + +int mlx_vjp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) { + return mlx_vjp_ptr(res_0, res_1, fun, primals, cotangents); +} + +int mlx_detail_vmap_replace(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num) { + return mlx_detail_vmap_replace_ptr(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num); +} + +int mlx_detail_vmap_trace(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num) { + return mlx_detail_vmap_trace_ptr(res_0, res_1, fun, inputs, in_axes, in_axes_num); +} + +mlx_vector_array mlx_vector_array_new(void) { + return mlx_vector_array_new_ptr(); +} + +int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) { + return mlx_vector_array_set_ptr(vec, src); +} + +int mlx_vector_array_free(mlx_vector_array vec) { + return mlx_vector_array_free_ptr(vec); +} + +mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size) { + return mlx_vector_array_new_data_ptr(data, size); +} + +mlx_vector_array mlx_vector_array_new_value(const mlx_array val) { + return mlx_vector_array_new_value_ptr(val); +} + +int mlx_vector_array_set_data(mlx_vector_array* vec, const mlx_array* data, size_t size) { + return mlx_vector_array_set_data_ptr(vec, data, size); +} + +int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) { + return mlx_vector_array_set_value_ptr(vec, val); +} + +int mlx_vector_array_append_data(mlx_vector_array vec, const mlx_array* data, size_t size) { + return mlx_vector_array_append_data_ptr(vec, data, size); +} + +int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) { + return mlx_vector_array_append_value_ptr(vec, val); +} + +size_t mlx_vector_array_size(mlx_vector_array vec) { + return mlx_vector_array_size_ptr(vec); +} + +int mlx_vector_array_get(mlx_array* res, const mlx_vector_array vec, size_t idx) { + return mlx_vector_array_get_ptr(res, vec, idx); +} + +mlx_vector_vector_array mlx_vector_vector_array_new(void) { + return mlx_vector_vector_array_new_ptr(); +} + +int mlx_vector_vector_array_set(mlx_vector_vector_array* vec, const mlx_vector_vector_array src) { + return mlx_vector_vector_array_set_ptr(vec, src); +} + +int mlx_vector_vector_array_free(mlx_vector_vector_array vec) { + return mlx_vector_vector_array_free_ptr(vec); +} + +mlx_vector_vector_array mlx_vector_vector_array_new_data(const mlx_vector_array* data, size_t size) { + return mlx_vector_vector_array_new_data_ptr(data, size); +} + +mlx_vector_vector_array mlx_vector_vector_array_new_value(const mlx_vector_array val) { + return mlx_vector_vector_array_new_value_ptr(val); +} + +int mlx_vector_vector_array_set_data(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) { + return mlx_vector_vector_array_set_data_ptr(vec, data, size); +} + +int mlx_vector_vector_array_set_value(mlx_vector_vector_array* vec, const mlx_vector_array val) { + return mlx_vector_vector_array_set_value_ptr(vec, val); +} + +int mlx_vector_vector_array_append_data(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) { + return mlx_vector_vector_array_append_data_ptr(vec, data, size); +} + +int mlx_vector_vector_array_append_value(mlx_vector_vector_array vec, const mlx_vector_array val) { + return mlx_vector_vector_array_append_value_ptr(vec, val); +} + +size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) { + return mlx_vector_vector_array_size_ptr(vec); +} + +int mlx_vector_vector_array_get(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) { + return mlx_vector_vector_array_get_ptr(res, vec, idx); +} + +mlx_vector_int mlx_vector_int_new(void) { + return mlx_vector_int_new_ptr(); +} + +int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) { + return mlx_vector_int_set_ptr(vec, src); +} + +int mlx_vector_int_free(mlx_vector_int vec) { + return mlx_vector_int_free_ptr(vec); +} + +mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) { + return mlx_vector_int_new_data_ptr(data, size); +} + +mlx_vector_int mlx_vector_int_new_value(int val) { + return mlx_vector_int_new_value_ptr(val); +} + +int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) { + return mlx_vector_int_set_data_ptr(vec, data, size); +} + +int mlx_vector_int_set_value(mlx_vector_int* vec, int val) { + return mlx_vector_int_set_value_ptr(vec, val); +} + +int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) { + return mlx_vector_int_append_data_ptr(vec, data, size); +} + +int mlx_vector_int_append_value(mlx_vector_int vec, int val) { + return mlx_vector_int_append_value_ptr(vec, val); +} + +size_t mlx_vector_int_size(mlx_vector_int vec) { + return mlx_vector_int_size_ptr(vec); +} + +int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) { + return mlx_vector_int_get_ptr(res, vec, idx); +} + +mlx_vector_string mlx_vector_string_new(void) { + return mlx_vector_string_new_ptr(); +} + +int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) { + return mlx_vector_string_set_ptr(vec, src); +} + +int mlx_vector_string_free(mlx_vector_string vec) { + return mlx_vector_string_free_ptr(vec); +} + +mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) { + return mlx_vector_string_new_data_ptr(data, size); +} + +mlx_vector_string mlx_vector_string_new_value(const char* val) { + return mlx_vector_string_new_value_ptr(val); +} + +int mlx_vector_string_set_data(mlx_vector_string* vec, const char** data, size_t size) { + return mlx_vector_string_set_data_ptr(vec, data, size); +} + +int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) { + return mlx_vector_string_set_value_ptr(vec, val); +} + +int mlx_vector_string_append_data(mlx_vector_string vec, const char** data, size_t size) { + return mlx_vector_string_append_data_ptr(vec, data, size); +} + +int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) { + return mlx_vector_string_append_value_ptr(vec, val); +} + +size_t mlx_vector_string_size(mlx_vector_string vec) { + return mlx_vector_string_size_ptr(vec); +} + +int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) { + return mlx_vector_string_get_ptr(res, vec, idx); +} + +int mlx_version(mlx_string* str_) { + return mlx_version_ptr(str_); +} + diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go index 9cb04e8f2..1ede04cf6 100644 --- a/x/imagegen/mlx/mlx.go +++ b/x/imagegen/mlx/mlx.go @@ -3,12 +3,13 @@ package mlx /* -#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/ +#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src -I${SRCDIR} #cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate -#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc +#cgo linux LDFLAGS: -lstdc++ -ldl +#cgo windows LDFLAGS: -lstdc++ -#include "mlx/c/mlx.h" +// Use generated wrappers instead of direct MLX headers +#include "mlx.h" #include #include #include @@ -42,192 +43,6 @@ static inline mlx_stream cpu_stream() { // CGO noescape/nocallback hints to reduce CGO overhead // noescape: pointers won't escape, no heap allocation needed // nocallback: function won't call back into Go -#cgo noescape mlx_add -#cgo nocallback mlx_add -#cgo noescape mlx_subtract -#cgo nocallback mlx_subtract -#cgo noescape mlx_multiply -#cgo nocallback mlx_multiply -#cgo noescape mlx_divide -#cgo nocallback mlx_divide -#cgo noescape mlx_negative -#cgo nocallback mlx_negative -#cgo noescape mlx_abs -#cgo nocallback mlx_abs -#cgo noescape mlx_exp -#cgo nocallback mlx_exp -#cgo noescape mlx_log -#cgo nocallback mlx_log -#cgo noescape mlx_sqrt -#cgo nocallback mlx_sqrt -#cgo noescape mlx_rsqrt -#cgo nocallback mlx_rsqrt -#cgo noescape mlx_square -#cgo nocallback mlx_square -#cgo noescape mlx_power -#cgo nocallback mlx_power -#cgo noescape mlx_erf -#cgo nocallback mlx_erf -#cgo noescape mlx_sigmoid -#cgo nocallback mlx_sigmoid -#cgo noescape mlx_tanh -#cgo nocallback mlx_tanh -#cgo noescape mlx_sin -#cgo nocallback mlx_sin -#cgo noescape mlx_cos -#cgo nocallback mlx_cos -#cgo noescape mlx_maximum -#cgo nocallback mlx_maximum -#cgo noescape mlx_minimum -#cgo nocallback mlx_minimum -#cgo noescape mlx_clip -#cgo nocallback mlx_clip -#cgo noescape mlx_sum -#cgo nocallback mlx_sum -#cgo noescape mlx_sum_axis -#cgo nocallback mlx_sum_axis -#cgo noescape mlx_mean -#cgo nocallback mlx_mean -#cgo noescape mlx_mean_axis -#cgo nocallback mlx_mean_axis -#cgo noescape mlx_var_axis -#cgo nocallback mlx_var_axis -#cgo noescape mlx_argmax -#cgo nocallback mlx_argmax -#cgo noescape mlx_argmax_axis -#cgo nocallback mlx_argmax_axis -#cgo noescape mlx_softmax_axis -#cgo nocallback mlx_softmax_axis -#cgo noescape mlx_cumsum -#cgo nocallback mlx_cumsum -#cgo noescape mlx_matmul -#cgo nocallback mlx_matmul -#cgo noescape mlx_addmm -#cgo nocallback mlx_addmm -#cgo noescape mlx_gather_mm -#cgo nocallback mlx_gather_mm -#cgo noescape mlx_gather_qmm -#cgo nocallback mlx_gather_qmm -#cgo noescape mlx_reshape -#cgo nocallback mlx_reshape -#cgo noescape mlx_transpose_axes -#cgo nocallback mlx_transpose_axes -#cgo noescape mlx_expand_dims -#cgo nocallback mlx_expand_dims -#cgo noescape mlx_squeeze_axis -#cgo nocallback mlx_squeeze_axis -#cgo noescape mlx_flatten -#cgo nocallback mlx_flatten -#cgo noescape mlx_concatenate_axis -#cgo nocallback mlx_concatenate_axis -#cgo noescape mlx_slice -#cgo nocallback mlx_slice -#cgo noescape mlx_slice_update -#cgo nocallback mlx_slice_update -#cgo noescape mlx_as_strided -#cgo nocallback mlx_as_strided -#cgo noescape mlx_view -#cgo nocallback mlx_view -#cgo noescape mlx_contiguous -#cgo nocallback mlx_contiguous -#cgo noescape mlx_pad -#cgo nocallback mlx_pad -#cgo noescape mlx_tile -#cgo nocallback mlx_tile -#cgo noescape mlx_take_axis -#cgo nocallback mlx_take_axis -#cgo noescape mlx_take_along_axis -#cgo nocallback mlx_take_along_axis -#cgo noescape mlx_put_along_axis -#cgo nocallback mlx_put_along_axis -#cgo noescape mlx_where -#cgo nocallback mlx_where -#cgo noescape mlx_argsort_axis -#cgo nocallback mlx_argsort_axis -#cgo noescape mlx_argpartition_axis -#cgo nocallback mlx_argpartition_axis -#cgo noescape mlx_topk_axis -#cgo nocallback mlx_topk_axis -#cgo noescape mlx_less -#cgo nocallback mlx_less -#cgo noescape mlx_greater_equal -#cgo nocallback mlx_greater_equal -#cgo noescape mlx_logical_and -#cgo nocallback mlx_logical_and -#cgo noescape mlx_zeros -#cgo nocallback mlx_zeros -#cgo noescape mlx_zeros_like -#cgo nocallback mlx_zeros_like -#cgo noescape mlx_ones -#cgo nocallback mlx_ones -#cgo noescape mlx_full -#cgo nocallback mlx_full -#cgo noescape mlx_arange -#cgo nocallback mlx_arange -#cgo noescape mlx_linspace -#cgo nocallback mlx_linspace -#cgo noescape mlx_tri -#cgo nocallback mlx_tri -#cgo noescape mlx_astype -#cgo nocallback mlx_astype -#cgo noescape mlx_fast_rms_norm -#cgo nocallback mlx_fast_rms_norm -#cgo noescape mlx_fast_rope -#cgo nocallback mlx_fast_rope -#cgo noescape mlx_fast_scaled_dot_product_attention -#cgo nocallback mlx_fast_scaled_dot_product_attention -#cgo noescape mlx_conv2d -#cgo nocallback mlx_conv2d -#cgo noescape mlx_conv3d -#cgo nocallback mlx_conv3d -#cgo noescape mlx_random_key -#cgo nocallback mlx_random_key -#cgo noescape mlx_random_split -#cgo nocallback mlx_random_split -#cgo noescape mlx_random_categorical_num_samples -#cgo nocallback mlx_random_categorical_num_samples -#cgo noescape mlx_random_normal -#cgo nocallback mlx_random_normal -#cgo noescape mlx_random_uniform -#cgo nocallback mlx_random_uniform -#cgo noescape mlx_array_eval -#cgo nocallback mlx_array_eval -#cgo noescape mlx_eval -#cgo nocallback mlx_eval -#cgo noescape mlx_async_eval -#cgo nocallback mlx_async_eval -#cgo noescape mlx_synchronize -#cgo nocallback mlx_synchronize -#cgo noescape mlx_array_new -#cgo nocallback mlx_array_new -#cgo noescape mlx_array_new_data -#cgo nocallback mlx_array_new_data -#cgo noescape mlx_array_new_float -#cgo nocallback mlx_array_new_float -#cgo noescape mlx_array_free -#cgo nocallback mlx_array_free -#cgo noescape mlx_array_size -#cgo nocallback mlx_array_size -#cgo noescape mlx_array_ndim -#cgo nocallback mlx_array_ndim -#cgo noescape mlx_array_dim -#cgo nocallback mlx_array_dim -#cgo noescape mlx_array_dtype -#cgo nocallback mlx_array_dtype -#cgo noescape mlx_array_item_int32 -#cgo nocallback mlx_array_item_int32 -#cgo noescape mlx_vector_array_new_data -#cgo nocallback mlx_vector_array_new_data -#cgo noescape mlx_vector_array_free -#cgo nocallback mlx_vector_array_free -#cgo noescape mlx_array_new_int -#cgo nocallback mlx_array_new_int -#cgo noescape mlx_stream_new_device -#cgo nocallback mlx_stream_new_device -#cgo noescape mlx_get_default_stream -#cgo nocallback mlx_get_default_stream -#cgo noescape mlx_set_default_stream -#cgo nocallback mlx_set_default_stream */ import "C" import ( @@ -1796,7 +1611,57 @@ func ArgmaxKeepArray(logits *Array) *Array { var RandomState = []*Array{nil} var randomStateMu sync.Mutex +var mlxInitialized bool +var mlxInitError error + +// InitMLX initializes the MLX library by dynamically loading libmlxc. +// This must be called before using any MLX functions. +// Returns an error if the library cannot be loaded. +func InitMLX() error { + if mlxInitialized { + return mlxInitError + } + + // Try to load the MLX dynamic library + ret := C.mlx_dynamic_init() + if ret != 0 { + errMsg := C.GoString(C.mlx_dynamic_error()) + mlxInitError = fmt.Errorf("failed to initialize MLX: %s", errMsg) + return mlxInitError + } + + // Initialize all function pointers via dlsym + handle := C.mlx_get_handle() + ret = C.mlx_load_functions(handle) + if ret != 0 { + mlxInitError = fmt.Errorf("failed to load MLX function symbols") + return mlxInitError + } + + mlxInitialized = true + mlxInitError = nil + return nil +} + +// IsMLXAvailable returns whether MLX was successfully initialized +func IsMLXAvailable() bool { + return mlxInitialized && mlxInitError == nil +} + +// GetMLXInitError returns any error that occurred during MLX initialization +func GetMLXInitError() error { + return mlxInitError +} + func init() { + // Initialize MLX dynamic library first + if err := InitMLX(); err != nil { + // Don't panic in init - let the caller handle the error + // Store the error for later retrieval + mlxInitError = err + return + } + // Lock main goroutine to OS thread for CUDA context stability. // CUDA contexts are bound to threads; Go can migrate goroutines between threads. runtime.LockOSThread() diff --git a/x/imagegen/mlx/mlx.h b/x/imagegen/mlx/mlx.h new file mode 100644 index 000000000..d4ed1a905 --- /dev/null +++ b/x/imagegen/mlx/mlx.h @@ -0,0 +1,2337 @@ +// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT +// This file provides wrapper declarations for MLX-C functions that use dlopen/dlsym +// +// Strategy: Include MLX-C headers for type definitions, then provide wrapper +// functions that shadow the originals, allowing Go code to call them directly (e.g., C.mlx_add). +// Function pointers are defined in mlx.c (single compilation unit). + +#ifndef MLX_WRAPPERS_H +#define MLX_WRAPPERS_H + +// Include MLX headers for type definitions and original declarations +#include "mlx/c/mlx.h" +#include "mlx_dynamic.h" +#include + +// Undefine any existing MLX function macros +#undef mlx_dtype_size +#undef mlx_array_tostring +#undef mlx_array_new +#undef mlx_array_free +#undef mlx_array_new_bool +#undef mlx_array_new_int +#undef mlx_array_new_float32 +#undef mlx_array_new_float +#undef mlx_array_new_float64 +#undef mlx_array_new_double +#undef mlx_array_new_complex +#undef mlx_array_new_data +#undef mlx_array_set +#undef mlx_array_set_bool +#undef mlx_array_set_int +#undef mlx_array_set_float32 +#undef mlx_array_set_float +#undef mlx_array_set_float64 +#undef mlx_array_set_double +#undef mlx_array_set_complex +#undef mlx_array_set_data +#undef mlx_array_itemsize +#undef mlx_array_size +#undef mlx_array_nbytes +#undef mlx_array_ndim +#undef mlx_array_shape +#undef mlx_array_strides +#undef mlx_array_dim +#undef mlx_array_dtype +#undef mlx_array_eval +#undef mlx_array_item_bool +#undef mlx_array_item_uint8 +#undef mlx_array_item_uint16 +#undef mlx_array_item_uint32 +#undef mlx_array_item_uint64 +#undef mlx_array_item_int8 +#undef mlx_array_item_int16 +#undef mlx_array_item_int32 +#undef mlx_array_item_int64 +#undef mlx_array_item_float32 +#undef mlx_array_item_float64 +#undef mlx_array_item_complex64 +#undef mlx_array_item_float16 +#undef mlx_array_item_bfloat16 +#undef mlx_array_data_bool +#undef mlx_array_data_uint8 +#undef mlx_array_data_uint16 +#undef mlx_array_data_uint32 +#undef mlx_array_data_uint64 +#undef mlx_array_data_int8 +#undef mlx_array_data_int16 +#undef mlx_array_data_int32 +#undef mlx_array_data_int64 +#undef mlx_array_data_float32 +#undef mlx_array_data_float64 +#undef mlx_array_data_complex64 +#undef mlx_array_data_float16 +#undef mlx_array_data_bfloat16 +#undef _mlx_array_is_available +#undef _mlx_array_wait +#undef _mlx_array_is_contiguous +#undef _mlx_array_is_row_contiguous +#undef _mlx_array_is_col_contiguous +#undef mlx_closure_new +#undef mlx_closure_free +#undef mlx_closure_new_func +#undef mlx_closure_new_func_payload +#undef mlx_closure_set +#undef mlx_closure_apply +#undef mlx_closure_new_unary +#undef mlx_closure_kwargs_new +#undef mlx_closure_kwargs_free +#undef mlx_closure_kwargs_new_func +#undef mlx_closure_kwargs_new_func_payload +#undef mlx_closure_kwargs_set +#undef mlx_closure_kwargs_apply +#undef mlx_closure_value_and_grad_new +#undef mlx_closure_value_and_grad_free +#undef mlx_closure_value_and_grad_new_func +#undef mlx_closure_value_and_grad_new_func_payload +#undef mlx_closure_value_and_grad_set +#undef mlx_closure_value_and_grad_apply +#undef mlx_closure_custom_new +#undef mlx_closure_custom_free +#undef mlx_closure_custom_new_func +#undef mlx_closure_custom_new_func_payload +#undef mlx_closure_custom_set +#undef mlx_closure_custom_apply +#undef mlx_closure_custom_jvp_new +#undef mlx_closure_custom_jvp_free +#undef mlx_closure_custom_jvp_new_func +#undef mlx_closure_custom_jvp_new_func_payload +#undef mlx_closure_custom_jvp_set +#undef mlx_closure_custom_jvp_apply +#undef mlx_closure_custom_vmap_new +#undef mlx_closure_custom_vmap_free +#undef mlx_closure_custom_vmap_new_func +#undef mlx_closure_custom_vmap_new_func_payload +#undef mlx_closure_custom_vmap_set +#undef mlx_closure_custom_vmap_apply +#undef mlx_compile +#undef mlx_detail_compile +#undef mlx_detail_compile_clear_cache +#undef mlx_detail_compile_erase +#undef mlx_disable_compile +#undef mlx_enable_compile +#undef mlx_set_compile_mode +#undef mlx_device_new +#undef mlx_device_new_type +#undef mlx_device_free +#undef mlx_device_set +#undef mlx_device_tostring +#undef mlx_device_equal +#undef mlx_device_get_index +#undef mlx_device_get_type +#undef mlx_get_default_device +#undef mlx_set_default_device +#undef mlx_distributed_all_gather +#undef mlx_distributed_all_max +#undef mlx_distributed_all_min +#undef mlx_distributed_all_sum +#undef mlx_distributed_recv +#undef mlx_distributed_recv_like +#undef mlx_distributed_send +#undef mlx_distributed_sum_scatter +#undef mlx_distributed_group_rank +#undef mlx_distributed_group_size +#undef mlx_distributed_group_split +#undef mlx_distributed_is_available +#undef mlx_distributed_init +#undef mlx_set_error_handler +#undef _mlx_error +#undef mlx_export_function +#undef mlx_export_function_kwargs +#undef mlx_function_exporter_new +#undef mlx_function_exporter_free +#undef mlx_function_exporter_apply +#undef mlx_function_exporter_apply_kwargs +#undef mlx_imported_function_new +#undef mlx_imported_function_free +#undef mlx_imported_function_apply +#undef mlx_imported_function_apply_kwargs +#undef mlx_fast_cuda_kernel_config_new +#undef mlx_fast_cuda_kernel_config_free +#undef mlx_fast_cuda_kernel_config_add_output_arg +#undef mlx_fast_cuda_kernel_config_set_grid +#undef mlx_fast_cuda_kernel_config_set_thread_group +#undef mlx_fast_cuda_kernel_config_set_init_value +#undef mlx_fast_cuda_kernel_config_set_verbose +#undef mlx_fast_cuda_kernel_config_add_template_arg_dtype +#undef mlx_fast_cuda_kernel_config_add_template_arg_int +#undef mlx_fast_cuda_kernel_config_add_template_arg_bool +#undef mlx_fast_cuda_kernel_new +#undef mlx_fast_cuda_kernel_free +#undef mlx_fast_cuda_kernel_apply +#undef mlx_fast_layer_norm +#undef mlx_fast_metal_kernel_config_new +#undef mlx_fast_metal_kernel_config_free +#undef mlx_fast_metal_kernel_config_add_output_arg +#undef mlx_fast_metal_kernel_config_set_grid +#undef mlx_fast_metal_kernel_config_set_thread_group +#undef mlx_fast_metal_kernel_config_set_init_value +#undef mlx_fast_metal_kernel_config_set_verbose +#undef mlx_fast_metal_kernel_config_add_template_arg_dtype +#undef mlx_fast_metal_kernel_config_add_template_arg_int +#undef mlx_fast_metal_kernel_config_add_template_arg_bool +#undef mlx_fast_metal_kernel_new +#undef mlx_fast_metal_kernel_free +#undef mlx_fast_metal_kernel_apply +#undef mlx_fast_rms_norm +#undef mlx_fast_rope +#undef mlx_fast_rope_dynamic +#undef mlx_fast_scaled_dot_product_attention +#undef mlx_fft_fft +#undef mlx_fft_fft2 +#undef mlx_fft_fftn +#undef mlx_fft_fftshift +#undef mlx_fft_ifft +#undef mlx_fft_ifft2 +#undef mlx_fft_ifftn +#undef mlx_fft_ifftshift +#undef mlx_fft_irfft +#undef mlx_fft_irfft2 +#undef mlx_fft_irfftn +#undef mlx_fft_rfft +#undef mlx_fft_rfft2 +#undef mlx_fft_rfftn +#undef mlx_load_reader +#undef mlx_load +#undef mlx_load_safetensors_reader +#undef mlx_load_safetensors +#undef mlx_save_writer +#undef mlx_save +#undef mlx_save_safetensors_writer +#undef mlx_save_safetensors +#undef mlx_io_reader_new +#undef mlx_io_reader_descriptor +#undef mlx_io_reader_tostring +#undef mlx_io_reader_free +#undef mlx_io_writer_new +#undef mlx_io_writer_descriptor +#undef mlx_io_writer_tostring +#undef mlx_io_writer_free +#undef mlx_linalg_cholesky +#undef mlx_linalg_cholesky_inv +#undef mlx_linalg_cross +#undef mlx_linalg_eig +#undef mlx_linalg_eigh +#undef mlx_linalg_eigvals +#undef mlx_linalg_eigvalsh +#undef mlx_linalg_inv +#undef mlx_linalg_lu +#undef mlx_linalg_lu_factor +#undef mlx_linalg_norm +#undef mlx_linalg_norm_matrix +#undef mlx_linalg_norm_l2 +#undef mlx_linalg_pinv +#undef mlx_linalg_qr +#undef mlx_linalg_solve +#undef mlx_linalg_solve_triangular +#undef mlx_linalg_svd +#undef mlx_linalg_tri_inv +#undef mlx_map_string_to_array_new +#undef mlx_map_string_to_array_set +#undef mlx_map_string_to_array_free +#undef mlx_map_string_to_array_insert +#undef mlx_map_string_to_array_get +#undef mlx_map_string_to_array_iterator_new +#undef mlx_map_string_to_array_iterator_free +#undef mlx_map_string_to_array_iterator_next +#undef mlx_map_string_to_string_new +#undef mlx_map_string_to_string_set +#undef mlx_map_string_to_string_free +#undef mlx_map_string_to_string_insert +#undef mlx_map_string_to_string_get +#undef mlx_map_string_to_string_iterator_new +#undef mlx_map_string_to_string_iterator_free +#undef mlx_map_string_to_string_iterator_next +#undef mlx_clear_cache +#undef mlx_get_active_memory +#undef mlx_get_cache_memory +#undef mlx_get_memory_limit +#undef mlx_get_peak_memory +#undef mlx_reset_peak_memory +#undef mlx_set_cache_limit +#undef mlx_set_memory_limit +#undef mlx_set_wired_limit +#undef mlx_metal_device_info +#undef mlx_metal_is_available +#undef mlx_metal_start_capture +#undef mlx_metal_stop_capture +#undef mlx_abs +#undef mlx_add +#undef mlx_addmm +#undef mlx_all_axes +#undef mlx_all_axis +#undef mlx_all +#undef mlx_allclose +#undef mlx_any_axes +#undef mlx_any_axis +#undef mlx_any +#undef mlx_arange +#undef mlx_arccos +#undef mlx_arccosh +#undef mlx_arcsin +#undef mlx_arcsinh +#undef mlx_arctan +#undef mlx_arctan2 +#undef mlx_arctanh +#undef mlx_argmax_axis +#undef mlx_argmax +#undef mlx_argmin_axis +#undef mlx_argmin +#undef mlx_argpartition_axis +#undef mlx_argpartition +#undef mlx_argsort_axis +#undef mlx_argsort +#undef mlx_array_equal +#undef mlx_as_strided +#undef mlx_astype +#undef mlx_atleast_1d +#undef mlx_atleast_2d +#undef mlx_atleast_3d +#undef mlx_bitwise_and +#undef mlx_bitwise_invert +#undef mlx_bitwise_or +#undef mlx_bitwise_xor +#undef mlx_block_masked_mm +#undef mlx_broadcast_arrays +#undef mlx_broadcast_to +#undef mlx_ceil +#undef mlx_clip +#undef mlx_concatenate_axis +#undef mlx_concatenate +#undef mlx_conjugate +#undef mlx_contiguous +#undef mlx_conv1d +#undef mlx_conv2d +#undef mlx_conv3d +#undef mlx_conv_general +#undef mlx_conv_transpose1d +#undef mlx_conv_transpose2d +#undef mlx_conv_transpose3d +#undef mlx_copy +#undef mlx_cos +#undef mlx_cosh +#undef mlx_cummax +#undef mlx_cummin +#undef mlx_cumprod +#undef mlx_cumsum +#undef mlx_degrees +#undef mlx_depends +#undef mlx_dequantize +#undef mlx_diag +#undef mlx_diagonal +#undef mlx_divide +#undef mlx_divmod +#undef mlx_einsum +#undef mlx_equal +#undef mlx_erf +#undef mlx_erfinv +#undef mlx_exp +#undef mlx_expand_dims_axes +#undef mlx_expand_dims +#undef mlx_expm1 +#undef mlx_eye +#undef mlx_flatten +#undef mlx_floor +#undef mlx_floor_divide +#undef mlx_from_fp8 +#undef mlx_full +#undef mlx_full_like +#undef mlx_gather +#undef mlx_gather_single +#undef mlx_gather_mm +#undef mlx_gather_qmm +#undef mlx_greater +#undef mlx_greater_equal +#undef mlx_hadamard_transform +#undef mlx_identity +#undef mlx_imag +#undef mlx_inner +#undef mlx_isclose +#undef mlx_isfinite +#undef mlx_isinf +#undef mlx_isnan +#undef mlx_isneginf +#undef mlx_isposinf +#undef mlx_kron +#undef mlx_left_shift +#undef mlx_less +#undef mlx_less_equal +#undef mlx_linspace +#undef mlx_log +#undef mlx_log10 +#undef mlx_log1p +#undef mlx_log2 +#undef mlx_logaddexp +#undef mlx_logcumsumexp +#undef mlx_logical_and +#undef mlx_logical_not +#undef mlx_logical_or +#undef mlx_logsumexp_axes +#undef mlx_logsumexp_axis +#undef mlx_logsumexp +#undef mlx_masked_scatter +#undef mlx_matmul +#undef mlx_max_axes +#undef mlx_max_axis +#undef mlx_max +#undef mlx_maximum +#undef mlx_mean_axes +#undef mlx_mean_axis +#undef mlx_mean +#undef mlx_median +#undef mlx_meshgrid +#undef mlx_min_axes +#undef mlx_min_axis +#undef mlx_min +#undef mlx_minimum +#undef mlx_moveaxis +#undef mlx_multiply +#undef mlx_nan_to_num +#undef mlx_negative +#undef mlx_not_equal +#undef mlx_number_of_elements +#undef mlx_ones +#undef mlx_ones_like +#undef mlx_outer +#undef mlx_pad +#undef mlx_pad_symmetric +#undef mlx_partition_axis +#undef mlx_partition +#undef mlx_power +#undef mlx_prod_axes +#undef mlx_prod_axis +#undef mlx_prod +#undef mlx_put_along_axis +#undef mlx_qqmm +#undef mlx_quantize +#undef mlx_quantized_matmul +#undef mlx_radians +#undef mlx_real +#undef mlx_reciprocal +#undef mlx_remainder +#undef mlx_repeat_axis +#undef mlx_repeat +#undef mlx_reshape +#undef mlx_right_shift +#undef mlx_roll_axis +#undef mlx_roll_axes +#undef mlx_roll +#undef mlx_round +#undef mlx_rsqrt +#undef mlx_scatter +#undef mlx_scatter_single +#undef mlx_scatter_add +#undef mlx_scatter_add_single +#undef mlx_scatter_add_axis +#undef mlx_scatter_max +#undef mlx_scatter_max_single +#undef mlx_scatter_min +#undef mlx_scatter_min_single +#undef mlx_scatter_prod +#undef mlx_scatter_prod_single +#undef mlx_segmented_mm +#undef mlx_sigmoid +#undef mlx_sign +#undef mlx_sin +#undef mlx_sinh +#undef mlx_slice +#undef mlx_slice_dynamic +#undef mlx_slice_update +#undef mlx_slice_update_dynamic +#undef mlx_softmax_axes +#undef mlx_softmax_axis +#undef mlx_softmax +#undef mlx_sort_axis +#undef mlx_sort +#undef mlx_split +#undef mlx_split_sections +#undef mlx_sqrt +#undef mlx_square +#undef mlx_squeeze_axes +#undef mlx_squeeze_axis +#undef mlx_squeeze +#undef mlx_stack_axis +#undef mlx_stack +#undef mlx_std_axes +#undef mlx_std_axis +#undef mlx_std +#undef mlx_stop_gradient +#undef mlx_subtract +#undef mlx_sum_axes +#undef mlx_sum_axis +#undef mlx_sum +#undef mlx_swapaxes +#undef mlx_take_axis +#undef mlx_take +#undef mlx_take_along_axis +#undef mlx_tan +#undef mlx_tanh +#undef mlx_tensordot +#undef mlx_tensordot_axis +#undef mlx_tile +#undef mlx_to_fp8 +#undef mlx_topk_axis +#undef mlx_topk +#undef mlx_trace +#undef mlx_transpose_axes +#undef mlx_transpose +#undef mlx_tri +#undef mlx_tril +#undef mlx_triu +#undef mlx_unflatten +#undef mlx_var_axes +#undef mlx_var_axis +#undef mlx_var +#undef mlx_view +#undef mlx_where +#undef mlx_zeros +#undef mlx_zeros_like +#undef mlx_random_bernoulli +#undef mlx_random_bits +#undef mlx_random_categorical_shape +#undef mlx_random_categorical_num_samples +#undef mlx_random_categorical +#undef mlx_random_gumbel +#undef mlx_random_key +#undef mlx_random_laplace +#undef mlx_random_multivariate_normal +#undef mlx_random_normal_broadcast +#undef mlx_random_normal +#undef mlx_random_permutation +#undef mlx_random_permutation_arange +#undef mlx_random_randint +#undef mlx_random_seed +#undef mlx_random_split_num +#undef mlx_random_split +#undef mlx_random_truncated_normal +#undef mlx_random_uniform +#undef mlx_stream_new +#undef mlx_stream_new_device +#undef mlx_stream_set +#undef mlx_stream_free +#undef mlx_stream_tostring +#undef mlx_stream_equal +#undef mlx_stream_get_device +#undef mlx_stream_get_index +#undef mlx_synchronize +#undef mlx_get_default_stream +#undef mlx_set_default_stream +#undef mlx_default_cpu_stream_new +#undef mlx_default_gpu_stream_new +#undef mlx_string_new +#undef mlx_string_new_data +#undef mlx_string_set +#undef mlx_string_data +#undef mlx_string_free +#undef mlx_async_eval +#undef mlx_checkpoint +#undef mlx_custom_function +#undef mlx_custom_vjp +#undef mlx_eval +#undef mlx_jvp +#undef mlx_value_and_grad +#undef mlx_vjp +#undef mlx_detail_vmap_replace +#undef mlx_detail_vmap_trace +#undef mlx_vector_array_new +#undef mlx_vector_array_set +#undef mlx_vector_array_free +#undef mlx_vector_array_new_data +#undef mlx_vector_array_new_value +#undef mlx_vector_array_set_data +#undef mlx_vector_array_set_value +#undef mlx_vector_array_append_data +#undef mlx_vector_array_append_value +#undef mlx_vector_array_size +#undef mlx_vector_array_get +#undef mlx_vector_vector_array_new +#undef mlx_vector_vector_array_set +#undef mlx_vector_vector_array_free +#undef mlx_vector_vector_array_new_data +#undef mlx_vector_vector_array_new_value +#undef mlx_vector_vector_array_set_data +#undef mlx_vector_vector_array_set_value +#undef mlx_vector_vector_array_append_data +#undef mlx_vector_vector_array_append_value +#undef mlx_vector_vector_array_size +#undef mlx_vector_vector_array_get +#undef mlx_vector_int_new +#undef mlx_vector_int_set +#undef mlx_vector_int_free +#undef mlx_vector_int_new_data +#undef mlx_vector_int_new_value +#undef mlx_vector_int_set_data +#undef mlx_vector_int_set_value +#undef mlx_vector_int_append_data +#undef mlx_vector_int_append_value +#undef mlx_vector_int_size +#undef mlx_vector_int_get +#undef mlx_vector_string_new +#undef mlx_vector_string_set +#undef mlx_vector_string_free +#undef mlx_vector_string_new_data +#undef mlx_vector_string_new_value +#undef mlx_vector_string_set_data +#undef mlx_vector_string_set_value +#undef mlx_vector_string_append_data +#undef mlx_vector_string_append_value +#undef mlx_vector_string_size +#undef mlx_vector_string_get +#undef mlx_version + +// Function pointer declarations (defined in mlx.c, loaded via dlsym) +extern size_t (*mlx_dtype_size_ptr)(mlx_dtype dtype); +extern int (*mlx_array_tostring_ptr)(mlx_string* str, const mlx_array arr); +extern mlx_array (*mlx_array_new_ptr)(void); +extern int (*mlx_array_free_ptr)(mlx_array arr); +extern mlx_array (*mlx_array_new_bool_ptr)(bool val); +extern mlx_array (*mlx_array_new_int_ptr)(int val); +extern mlx_array (*mlx_array_new_float32_ptr)(float val); +extern mlx_array (*mlx_array_new_float_ptr)(float val); +extern mlx_array (*mlx_array_new_float64_ptr)(double val); +extern mlx_array (*mlx_array_new_double_ptr)(double val); +extern mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val); +extern mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype); +extern int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src); +extern int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val); +extern int (*mlx_array_set_int_ptr)(mlx_array* arr, int val); +extern int (*mlx_array_set_float32_ptr)(mlx_array* arr, float val); +extern int (*mlx_array_set_float_ptr)(mlx_array* arr, float val); +extern int (*mlx_array_set_float64_ptr)(mlx_array* arr, double val); +extern int (*mlx_array_set_double_ptr)(mlx_array* arr, double val); +extern int (*mlx_array_set_complex_ptr)(mlx_array* arr, float real_val, float imag_val); +extern int (*mlx_array_set_data_ptr)(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype); +extern size_t (*mlx_array_itemsize_ptr)(const mlx_array arr); +extern size_t (*mlx_array_size_ptr)(const mlx_array arr); +extern size_t (*mlx_array_nbytes_ptr)(const mlx_array arr); +extern size_t (*mlx_array_ndim_ptr)(const mlx_array arr); +extern const int* (*mlx_array_shape_ptr)(const mlx_array arr); +extern const size_t* (*mlx_array_strides_ptr)(const mlx_array arr); +extern int (*mlx_array_dim_ptr)(const mlx_array arr, int dim); +extern mlx_dtype (*mlx_array_dtype_ptr)(const mlx_array arr); +extern int (*mlx_array_eval_ptr)(mlx_array arr); +extern int (*mlx_array_item_bool_ptr)(bool* res, const mlx_array arr); +extern int (*mlx_array_item_uint8_ptr)(uint8_t* res, const mlx_array arr); +extern int (*mlx_array_item_uint16_ptr)(uint16_t* res, const mlx_array arr); +extern int (*mlx_array_item_uint32_ptr)(uint32_t* res, const mlx_array arr); +extern int (*mlx_array_item_uint64_ptr)(uint64_t* res, const mlx_array arr); +extern int (*mlx_array_item_int8_ptr)(int8_t* res, const mlx_array arr); +extern int (*mlx_array_item_int16_ptr)(int16_t* res, const mlx_array arr); +extern int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr); +extern int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr); +extern int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr); +extern int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr); +extern int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr); +#if defined(__aarch64__) || defined(_M_ARM64) +extern int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr); +#endif +#if defined(__aarch64__) || defined(_M_ARM64) +extern int (*mlx_array_item_bfloat16_ptr)(bfloat16_t* res, const mlx_array arr); +#endif +extern const bool* (*mlx_array_data_bool_ptr)(const mlx_array arr); +extern const uint8_t* (*mlx_array_data_uint8_ptr)(const mlx_array arr); +extern const uint16_t* (*mlx_array_data_uint16_ptr)(const mlx_array arr); +extern const uint32_t* (*mlx_array_data_uint32_ptr)(const mlx_array arr); +extern const uint64_t* (*mlx_array_data_uint64_ptr)(const mlx_array arr); +extern const int8_t* (*mlx_array_data_int8_ptr)(const mlx_array arr); +extern const int16_t* (*mlx_array_data_int16_ptr)(const mlx_array arr); +extern const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr); +extern const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr); +extern const float* (*mlx_array_data_float32_ptr)(const mlx_array arr); +extern const double* (*mlx_array_data_float64_ptr)(const mlx_array arr); +extern const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr); +#if defined(__aarch64__) || defined(_M_ARM64) +extern const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr); +#endif +#if defined(__aarch64__) || defined(_M_ARM64) +extern const bfloat16_t* (*mlx_array_data_bfloat16_ptr)(const mlx_array arr); +#endif +extern int (*_mlx_array_is_available_ptr)(bool* res, const mlx_array arr); +extern int (*_mlx_array_wait_ptr)(const mlx_array arr); +extern int (*_mlx_array_is_contiguous_ptr)(bool* res, const mlx_array arr); +extern int (*_mlx_array_is_row_contiguous_ptr)(bool* res, const mlx_array arr); +extern int (*_mlx_array_is_col_contiguous_ptr)(bool* res, const mlx_array arr); +extern mlx_closure (*mlx_closure_new_ptr)(void); +extern int (*mlx_closure_free_ptr)(mlx_closure cls); +extern mlx_closure (*mlx_closure_new_func_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array)); +extern mlx_closure (*mlx_closure_new_func_payload_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); +extern int (*mlx_closure_set_ptr)(mlx_closure* cls, const mlx_closure src); +extern int (*mlx_closure_apply_ptr)(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input); +extern mlx_closure (*mlx_closure_new_unary_ptr)(int (*fun)(mlx_array*, const mlx_array)); +extern mlx_closure_kwargs (*mlx_closure_kwargs_new_ptr)(void); +extern int (*mlx_closure_kwargs_free_ptr)(mlx_closure_kwargs cls); +extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)); +extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)); +extern int (*mlx_closure_kwargs_set_ptr)(mlx_closure_kwargs* cls, const mlx_closure_kwargs src); +extern int (*mlx_closure_kwargs_apply_ptr)(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1); +extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_ptr)(void); +extern int (*mlx_closure_value_and_grad_free_ptr)(mlx_closure_value_and_grad cls); +extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_ptr)(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)); +extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); +extern int (*mlx_closure_value_and_grad_set_ptr)(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src); +extern int (*mlx_closure_value_and_grad_apply_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input); +extern mlx_closure_custom (*mlx_closure_custom_new_ptr)(void); +extern int (*mlx_closure_custom_free_ptr)(mlx_closure_custom cls); +extern mlx_closure_custom (*mlx_closure_custom_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)); +extern mlx_closure_custom (*mlx_closure_custom_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); +extern int (*mlx_closure_custom_set_ptr)(mlx_closure_custom* cls, const mlx_closure_custom src); +extern int (*mlx_closure_custom_apply_ptr)(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2); +extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_ptr)(void); +extern int (*mlx_closure_custom_jvp_free_ptr)(mlx_closure_custom_jvp cls); +extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)); +extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)); +extern int (*mlx_closure_custom_jvp_set_ptr)(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src); +extern int (*mlx_closure_custom_jvp_apply_ptr)(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num); +extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_ptr)(void); +extern int (*mlx_closure_custom_vmap_free_ptr)(mlx_closure_custom_vmap cls); +extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)); +extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)); +extern int (*mlx_closure_custom_vmap_set_ptr)(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src); +extern int (*mlx_closure_custom_vmap_apply_ptr)(mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num); +extern int (*mlx_compile_ptr)(mlx_closure* res, const mlx_closure fun, bool shapeless); +extern int (*mlx_detail_compile_ptr)(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num); +extern int (*mlx_detail_compile_clear_cache_ptr)(void); +extern int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id); +extern int (*mlx_disable_compile_ptr)(void); +extern int (*mlx_enable_compile_ptr)(void); +extern int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode); +extern mlx_device (*mlx_device_new_ptr)(void); +extern mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index); +extern int (*mlx_device_free_ptr)(mlx_device dev); +extern int (*mlx_device_set_ptr)(mlx_device* dev, const mlx_device src); +extern int (*mlx_device_tostring_ptr)(mlx_string* str, mlx_device dev); +extern bool (*mlx_device_equal_ptr)(mlx_device lhs, mlx_device rhs); +extern int (*mlx_device_get_index_ptr)(int* index, mlx_device dev); +extern int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev); +extern int (*mlx_get_default_device_ptr)(mlx_device* dev); +extern int (*mlx_set_default_device_ptr)(mlx_device dev); +extern int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S); +extern int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); +extern int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); +extern int (*mlx_distributed_all_sum_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); +extern int (*mlx_distributed_recv_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s); +extern int (*mlx_distributed_recv_like_ptr)(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s); +extern int (*mlx_distributed_send_ptr)(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s); +extern int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); +extern int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group); +extern int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group); +extern mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key); +extern bool (*mlx_distributed_is_available_ptr)(void); +extern mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict); +extern void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); +extern void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...); +extern int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless); +extern int (*mlx_export_function_kwargs_ptr)(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless); +extern mlx_function_exporter (*mlx_function_exporter_new_ptr)(const char* file, const mlx_closure fun, bool shapeless); +extern int (*mlx_function_exporter_free_ptr)(mlx_function_exporter xfunc); +extern int (*mlx_function_exporter_apply_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args); +extern int (*mlx_function_exporter_apply_kwargs_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs); +extern mlx_imported_function (*mlx_imported_function_new_ptr)(const char* file); +extern int (*mlx_imported_function_free_ptr)(mlx_imported_function xfunc); +extern int (*mlx_imported_function_apply_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args); +extern int (*mlx_imported_function_apply_kwargs_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs); +extern mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_ptr)(void); +extern void (*mlx_fast_cuda_kernel_config_free_ptr)(mlx_fast_cuda_kernel_config cls); +extern int (*mlx_fast_cuda_kernel_config_add_output_arg_ptr)(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype); +extern int (*mlx_fast_cuda_kernel_config_set_grid_ptr)(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3); +extern int (*mlx_fast_cuda_kernel_config_set_thread_group_ptr)(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3); +extern int (*mlx_fast_cuda_kernel_config_set_init_value_ptr)(mlx_fast_cuda_kernel_config cls, float value); +extern int (*mlx_fast_cuda_kernel_config_set_verbose_ptr)(mlx_fast_cuda_kernel_config cls, bool verbose); +extern int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype); +extern int (*mlx_fast_cuda_kernel_config_add_template_arg_int_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, int value); +extern int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, bool value); +extern mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory); +extern void (*mlx_fast_cuda_kernel_free_ptr)(mlx_fast_cuda_kernel cls); +extern int (*mlx_fast_cuda_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream); +extern int (*mlx_fast_layer_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s); +extern mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_ptr)(void); +extern void (*mlx_fast_metal_kernel_config_free_ptr)(mlx_fast_metal_kernel_config cls); +extern int (*mlx_fast_metal_kernel_config_add_output_arg_ptr)(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype); +extern int (*mlx_fast_metal_kernel_config_set_grid_ptr)(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3); +extern int (*mlx_fast_metal_kernel_config_set_thread_group_ptr)(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3); +extern int (*mlx_fast_metal_kernel_config_set_init_value_ptr)(mlx_fast_metal_kernel_config cls, float value); +extern int (*mlx_fast_metal_kernel_config_set_verbose_ptr)(mlx_fast_metal_kernel_config cls, bool verbose); +extern int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype); +extern int (*mlx_fast_metal_kernel_config_add_template_arg_int_ptr)(mlx_fast_metal_kernel_config cls, const char* name, int value); +extern int (*mlx_fast_metal_kernel_config_add_template_arg_bool_ptr)(mlx_fast_metal_kernel_config cls, const char* name, bool value); +extern mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs); +extern void (*mlx_fast_metal_kernel_free_ptr)(mlx_fast_metal_kernel cls); +extern int (*mlx_fast_metal_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream); +extern int (*mlx_fast_rms_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s); +extern int (*mlx_fast_rope_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s); +extern int (*mlx_fast_rope_dynamic_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s); +extern int (*mlx_fast_scaled_dot_product_attention_ptr)(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s); +extern int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); +extern int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_fftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); +extern int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_ifftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); +extern int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); +extern int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_load_reader_ptr)(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s); +extern int (*mlx_load_ptr)(mlx_array* res, const char* file, const mlx_stream s); +extern int (*mlx_load_safetensors_reader_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s); +extern int (*mlx_load_safetensors_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s); +extern int (*mlx_save_writer_ptr)(mlx_io_writer out_stream, const mlx_array a); +extern int (*mlx_save_ptr)(const char* file, const mlx_array a); +extern int (*mlx_save_safetensors_writer_ptr)(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); +extern int (*mlx_save_safetensors_ptr)(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); +extern mlx_io_reader (*mlx_io_reader_new_ptr)(void* desc, mlx_io_vtable vtable); +extern int (*mlx_io_reader_descriptor_ptr)(void** desc_, mlx_io_reader io); +extern int (*mlx_io_reader_tostring_ptr)(mlx_string* str_, mlx_io_reader io); +extern int (*mlx_io_reader_free_ptr)(mlx_io_reader io); +extern mlx_io_writer (*mlx_io_writer_new_ptr)(void* desc, mlx_io_vtable vtable); +extern int (*mlx_io_writer_descriptor_ptr)(void** desc_, mlx_io_writer io); +extern int (*mlx_io_writer_tostring_ptr)(mlx_string* str_, mlx_io_writer io); +extern int (*mlx_io_writer_free_ptr)(mlx_io_writer io); +extern int (*mlx_linalg_cholesky_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); +extern int (*mlx_linalg_cholesky_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); +extern int (*mlx_linalg_cross_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s); +extern int (*mlx_linalg_eig_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_eigh_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s); +extern int (*mlx_linalg_eigvals_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_eigvalsh_ptr)(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s); +extern int (*mlx_linalg_inv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_lu_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_lu_factor_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_norm_ptr)(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s); +extern int (*mlx_linalg_norm_matrix_ptr)(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s); +extern int (*mlx_linalg_norm_l2_ptr)(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s); +extern int (*mlx_linalg_pinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_qr_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); +extern int (*mlx_linalg_solve_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_linalg_solve_triangular_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s); +extern int (*mlx_linalg_svd_ptr)(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s); +extern int (*mlx_linalg_tri_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); +extern mlx_map_string_to_array (*mlx_map_string_to_array_new_ptr)(void); +extern int (*mlx_map_string_to_array_set_ptr)(mlx_map_string_to_array* map, const mlx_map_string_to_array src); +extern int (*mlx_map_string_to_array_free_ptr)(mlx_map_string_to_array map); +extern int (*mlx_map_string_to_array_insert_ptr)(mlx_map_string_to_array map, const char* key, const mlx_array value); +extern int (*mlx_map_string_to_array_get_ptr)(mlx_array* value, const mlx_map_string_to_array map, const char* key); +extern mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_ptr)(mlx_map_string_to_array map); +extern int (*mlx_map_string_to_array_iterator_free_ptr)(mlx_map_string_to_array_iterator it); +extern int (*mlx_map_string_to_array_iterator_next_ptr)(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it); +extern mlx_map_string_to_string (*mlx_map_string_to_string_new_ptr)(void); +extern int (*mlx_map_string_to_string_set_ptr)(mlx_map_string_to_string* map, const mlx_map_string_to_string src); +extern int (*mlx_map_string_to_string_free_ptr)(mlx_map_string_to_string map); +extern int (*mlx_map_string_to_string_insert_ptr)(mlx_map_string_to_string map, const char* key, const char* value); +extern int (*mlx_map_string_to_string_get_ptr)(const char** value, const mlx_map_string_to_string map, const char* key); +extern mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_ptr)(mlx_map_string_to_string map); +extern int (*mlx_map_string_to_string_iterator_free_ptr)(mlx_map_string_to_string_iterator it); +extern int (*mlx_map_string_to_string_iterator_next_ptr)(const char** key, const char** value, mlx_map_string_to_string_iterator it); +extern int (*mlx_clear_cache_ptr)(void); +extern int (*mlx_get_active_memory_ptr)(size_t* res); +extern int (*mlx_get_cache_memory_ptr)(size_t* res); +extern int (*mlx_get_memory_limit_ptr)(size_t* res); +extern int (*mlx_get_peak_memory_ptr)(size_t* res); +extern int (*mlx_reset_peak_memory_ptr)(void); +extern int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit); +extern int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit); +extern int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit); +extern mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void); +extern int (*mlx_metal_is_available_ptr)(bool* res); +extern int (*mlx_metal_start_capture_ptr)(const char* path); +extern int (*mlx_metal_stop_capture_ptr)(void); +extern int (*mlx_abs_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_add_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_addmm_ptr)(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s); +extern int (*mlx_all_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_all_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_all_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_allclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s); +extern int (*mlx_any_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_any_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_any_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_arange_ptr)(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_arccos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arccosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arcsin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arcsinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arctan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_arctan2_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_arctanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_argmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_argmax_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_argmin_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_argmin_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_argpartition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s); +extern int (*mlx_argpartition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s); +extern int (*mlx_argsort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s); +extern int (*mlx_argsort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_array_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s); +extern int (*mlx_as_strided_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s); +extern int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s); +extern int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); +extern int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s); +extern int (*mlx_ceil_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_clip_ptr)(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s); +extern int (*mlx_concatenate_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s); +extern int (*mlx_concatenate_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s); +extern int (*mlx_conjugate_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_contiguous_ptr)(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s); +extern int (*mlx_conv1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s); +extern int (*mlx_conv2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s); +extern int (*mlx_conv3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s); +extern int (*mlx_conv_general_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s); +extern int (*mlx_conv_transpose1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s); +extern int (*mlx_conv_transpose2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s); +extern int (*mlx_conv_transpose3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s); +extern int (*mlx_copy_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_cos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_cosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_cummax_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); +extern int (*mlx_cummin_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); +extern int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); +extern int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); +extern int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies); +extern int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s); +extern int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +extern int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s); +extern int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_divmod_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_einsum_ptr)(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s); +extern int (*mlx_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_erf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_erfinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_exp_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_expand_dims_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_expand_dims_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s); +extern int (*mlx_expm1_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_eye_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_flatten_ptr)(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s); +extern int (*mlx_floor_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_floor_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_from_fp8_ptr)(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_full_ptr)(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_full_like_ptr)(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_gather_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s); +extern int (*mlx_gather_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s); +extern int (*mlx_gather_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s); +extern int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s); +extern int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); +extern int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_isclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s); +extern int (*mlx_isfinite_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_isinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_isnan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_isneginf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_isposinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_kron_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_left_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_less_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_less_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_linspace_ptr)(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_log_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_log10_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_log1p_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_log2_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_logaddexp_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_logcumsumexp_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); +extern int (*mlx_logical_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_logical_not_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_logical_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_logsumexp_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_logsumexp_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_logsumexp_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_masked_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s); +extern int (*mlx_matmul_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_max_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_max_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_max_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_maximum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_mean_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_mean_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_mean_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_median_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_meshgrid_ptr)(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s); +extern int (*mlx_min_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_min_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_min_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_minimum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_moveaxis_ptr)(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s); +extern int (*mlx_multiply_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_nan_to_num_ptr)(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s); +extern int (*mlx_negative_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_not_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_number_of_elements_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_ones_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_ones_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_outer_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_pad_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s); +extern int (*mlx_pad_symmetric_ptr)(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s); +extern int (*mlx_partition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s); +extern int (*mlx_partition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s); +extern int (*mlx_power_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); +extern int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); +extern int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); +extern int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); +extern int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_reciprocal_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_remainder_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_repeat_axis_ptr)(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s); +extern int (*mlx_repeat_ptr)(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s); +extern int (*mlx_reshape_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s); +extern int (*mlx_right_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_roll_axis_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s); +extern int (*mlx_roll_axes_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_roll_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s); +extern int (*mlx_round_ptr)(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s); +extern int (*mlx_rsqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); +extern int (*mlx_scatter_add_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_add_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); +extern int (*mlx_scatter_add_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); +extern int (*mlx_scatter_max_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_max_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); +extern int (*mlx_scatter_min_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_min_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); +extern int (*mlx_scatter_prod_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_scatter_prod_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); +extern int (*mlx_segmented_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s); +extern int (*mlx_sigmoid_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_sign_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_sin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_sinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_slice_ptr)(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); +extern int (*mlx_slice_dynamic_ptr)(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s); +extern int (*mlx_slice_update_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); +extern int (*mlx_slice_update_dynamic_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_softmax_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s); +extern int (*mlx_softmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s); +extern int (*mlx_softmax_ptr)(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s); +extern int (*mlx_sort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s); +extern int (*mlx_sort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_split_ptr)(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s); +extern int (*mlx_split_sections_ptr)(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s); +extern int (*mlx_sqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_square_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_squeeze_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_squeeze_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s); +extern int (*mlx_squeeze_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_stack_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s); +extern int (*mlx_stack_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s); +extern int (*mlx_std_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s); +extern int (*mlx_std_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s); +extern int (*mlx_std_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s); +extern int (*mlx_stop_gradient_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_subtract_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); +extern int (*mlx_sum_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); +extern int (*mlx_sum_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); +extern int (*mlx_sum_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); +extern int (*mlx_swapaxes_ptr)(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s); +extern int (*mlx_take_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s); +extern int (*mlx_take_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s); +extern int (*mlx_take_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s); +extern int (*mlx_tan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_tanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_tensordot_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s); +extern int (*mlx_tensordot_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s); +extern int (*mlx_tile_ptr)(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s); +extern int (*mlx_to_fp8_ptr)(mlx_array* res, const mlx_array x, const mlx_stream s); +extern int (*mlx_topk_axis_ptr)(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s); +extern int (*mlx_topk_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +extern int (*mlx_trace_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_transpose_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); +extern int (*mlx_transpose_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_tri_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s); +extern int (*mlx_tril_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +extern int (*mlx_triu_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +extern int (*mlx_unflatten_ptr)(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s); +extern int (*mlx_var_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s); +extern int (*mlx_var_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s); +extern int (*mlx_var_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s); +extern int (*mlx_view_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_where_ptr)(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s); +extern int (*mlx_zeros_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s); +extern int (*mlx_zeros_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s); +extern int (*mlx_random_bernoulli_ptr)(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_bits_ptr)(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_categorical_shape_ptr)(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_categorical_num_samples_ptr)(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_categorical_ptr)(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_gumbel_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_key_ptr)(mlx_array* res, uint64_t seed); +extern int (*mlx_random_laplace_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_multivariate_normal_ptr)(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_normal_broadcast_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s); +extern int (*mlx_random_normal_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_permutation_ptr)(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_permutation_arange_ptr)(mlx_array* res, int x, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_randint_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_seed_ptr)(uint64_t seed); +extern int (*mlx_random_split_num_ptr)(mlx_array* res, const mlx_array key, int num, const mlx_stream s); +extern int (*mlx_random_split_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s); +extern int (*mlx_random_truncated_normal_ptr)(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); +extern int (*mlx_random_uniform_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); +extern mlx_stream (*mlx_stream_new_ptr)(void); +extern mlx_stream (*mlx_stream_new_device_ptr)(mlx_device dev); +extern int (*mlx_stream_set_ptr)(mlx_stream* stream, const mlx_stream src); +extern int (*mlx_stream_free_ptr)(mlx_stream stream); +extern int (*mlx_stream_tostring_ptr)(mlx_string* str, mlx_stream stream); +extern bool (*mlx_stream_equal_ptr)(mlx_stream lhs, mlx_stream rhs); +extern int (*mlx_stream_get_device_ptr)(mlx_device* dev, mlx_stream stream); +extern int (*mlx_stream_get_index_ptr)(int* index, mlx_stream stream); +extern int (*mlx_synchronize_ptr)(mlx_stream stream); +extern int (*mlx_get_default_stream_ptr)(mlx_stream* stream, mlx_device dev); +extern int (*mlx_set_default_stream_ptr)(mlx_stream stream); +extern mlx_stream (*mlx_default_cpu_stream_new_ptr)(void); +extern mlx_stream (*mlx_default_gpu_stream_new_ptr)(void); +extern mlx_string (*mlx_string_new_ptr)(void); +extern mlx_string (*mlx_string_new_data_ptr)(const char* str); +extern int (*mlx_string_set_ptr)(mlx_string* str, const mlx_string src); +extern const char* (*mlx_string_data_ptr)(mlx_string str); +extern int (*mlx_string_free_ptr)(mlx_string str); +extern int (*mlx_async_eval_ptr)(const mlx_vector_array outputs); +extern int (*mlx_checkpoint_ptr)(mlx_closure* res, const mlx_closure fun); +extern int (*mlx_custom_function_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap); +extern int (*mlx_custom_vjp_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp); +extern int (*mlx_eval_ptr)(const mlx_vector_array outputs); +extern int (*mlx_jvp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents); +extern int (*mlx_value_and_grad_ptr)(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num); +extern int (*mlx_vjp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents); +extern int (*mlx_detail_vmap_replace_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num); +extern int (*mlx_detail_vmap_trace_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num); +extern mlx_vector_array (*mlx_vector_array_new_ptr)(void); +extern int (*mlx_vector_array_set_ptr)(mlx_vector_array* vec, const mlx_vector_array src); +extern int (*mlx_vector_array_free_ptr)(mlx_vector_array vec); +extern mlx_vector_array (*mlx_vector_array_new_data_ptr)(const mlx_array* data, size_t size); +extern mlx_vector_array (*mlx_vector_array_new_value_ptr)(const mlx_array val); +extern int (*mlx_vector_array_set_data_ptr)(mlx_vector_array* vec, const mlx_array* data, size_t size); +extern int (*mlx_vector_array_set_value_ptr)(mlx_vector_array* vec, const mlx_array val); +extern int (*mlx_vector_array_append_data_ptr)(mlx_vector_array vec, const mlx_array* data, size_t size); +extern int (*mlx_vector_array_append_value_ptr)(mlx_vector_array vec, const mlx_array val); +extern size_t (*mlx_vector_array_size_ptr)(mlx_vector_array vec); +extern int (*mlx_vector_array_get_ptr)(mlx_array* res, const mlx_vector_array vec, size_t idx); +extern mlx_vector_vector_array (*mlx_vector_vector_array_new_ptr)(void); +extern int (*mlx_vector_vector_array_set_ptr)(mlx_vector_vector_array* vec, const mlx_vector_vector_array src); +extern int (*mlx_vector_vector_array_free_ptr)(mlx_vector_vector_array vec); +extern mlx_vector_vector_array (*mlx_vector_vector_array_new_data_ptr)(const mlx_vector_array* data, size_t size); +extern mlx_vector_vector_array (*mlx_vector_vector_array_new_value_ptr)(const mlx_vector_array val); +extern int (*mlx_vector_vector_array_set_data_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size); +extern int (*mlx_vector_vector_array_set_value_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array val); +extern int (*mlx_vector_vector_array_append_data_ptr)(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size); +extern int (*mlx_vector_vector_array_append_value_ptr)(mlx_vector_vector_array vec, const mlx_vector_array val); +extern size_t (*mlx_vector_vector_array_size_ptr)(mlx_vector_vector_array vec); +extern int (*mlx_vector_vector_array_get_ptr)(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx); +extern mlx_vector_int (*mlx_vector_int_new_ptr)(void); +extern int (*mlx_vector_int_set_ptr)(mlx_vector_int* vec, const mlx_vector_int src); +extern int (*mlx_vector_int_free_ptr)(mlx_vector_int vec); +extern mlx_vector_int (*mlx_vector_int_new_data_ptr)(int* data, size_t size); +extern mlx_vector_int (*mlx_vector_int_new_value_ptr)(int val); +extern int (*mlx_vector_int_set_data_ptr)(mlx_vector_int* vec, int* data, size_t size); +extern int (*mlx_vector_int_set_value_ptr)(mlx_vector_int* vec, int val); +extern int (*mlx_vector_int_append_data_ptr)(mlx_vector_int vec, int* data, size_t size); +extern int (*mlx_vector_int_append_value_ptr)(mlx_vector_int vec, int val); +extern size_t (*mlx_vector_int_size_ptr)(mlx_vector_int vec); +extern int (*mlx_vector_int_get_ptr)(int* res, const mlx_vector_int vec, size_t idx); +extern mlx_vector_string (*mlx_vector_string_new_ptr)(void); +extern int (*mlx_vector_string_set_ptr)(mlx_vector_string* vec, const mlx_vector_string src); +extern int (*mlx_vector_string_free_ptr)(mlx_vector_string vec); +extern mlx_vector_string (*mlx_vector_string_new_data_ptr)(const char** data, size_t size); +extern mlx_vector_string (*mlx_vector_string_new_value_ptr)(const char* val); +extern int (*mlx_vector_string_set_data_ptr)(mlx_vector_string* vec, const char** data, size_t size); +extern int (*mlx_vector_string_set_value_ptr)(mlx_vector_string* vec, const char* val); +extern int (*mlx_vector_string_append_data_ptr)(mlx_vector_string vec, const char** data, size_t size); +extern int (*mlx_vector_string_append_value_ptr)(mlx_vector_string vec, const char* val); +extern size_t (*mlx_vector_string_size_ptr)(mlx_vector_string vec); +extern int (*mlx_vector_string_get_ptr)(char** res, const mlx_vector_string vec, size_t idx); +extern int (*mlx_version_ptr)(mlx_string* str_); + +// Initialize all function pointers via dlsym (defined in mlx.c) +int mlx_load_functions(void* handle); + +// Wrapper function declarations that call through function pointers +// Go code calls these directly as C.mlx_* (no #define redirection needed) +size_t mlx_dtype_size(mlx_dtype dtype); + +int mlx_array_tostring(mlx_string* str, const mlx_array arr); + +mlx_array mlx_array_new(void); + +int mlx_array_free(mlx_array arr); + +mlx_array mlx_array_new_bool(bool val); + +mlx_array mlx_array_new_int(int val); + +mlx_array mlx_array_new_float32(float val); + +mlx_array mlx_array_new_float(float val); + +mlx_array mlx_array_new_float64(double val); + +mlx_array mlx_array_new_double(double val); + +mlx_array mlx_array_new_complex(float real_val, float imag_val); + +mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype); + +int mlx_array_set(mlx_array* arr, const mlx_array src); + +int mlx_array_set_bool(mlx_array* arr, bool val); + +int mlx_array_set_int(mlx_array* arr, int val); + +int mlx_array_set_float32(mlx_array* arr, float val); + +int mlx_array_set_float(mlx_array* arr, float val); + +int mlx_array_set_float64(mlx_array* arr, double val); + +int mlx_array_set_double(mlx_array* arr, double val); + +int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val); + +int mlx_array_set_data(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype); + +size_t mlx_array_itemsize(const mlx_array arr); + +size_t mlx_array_size(const mlx_array arr); + +size_t mlx_array_nbytes(const mlx_array arr); + +size_t mlx_array_ndim(const mlx_array arr); + +const int* mlx_array_shape(const mlx_array arr); + +const size_t* mlx_array_strides(const mlx_array arr); + +int mlx_array_dim(const mlx_array arr, int dim); + +mlx_dtype mlx_array_dtype(const mlx_array arr); + +int mlx_array_eval(mlx_array arr); + +int mlx_array_item_bool(bool* res, const mlx_array arr); + +int mlx_array_item_uint8(uint8_t* res, const mlx_array arr); + +int mlx_array_item_uint16(uint16_t* res, const mlx_array arr); + +int mlx_array_item_uint32(uint32_t* res, const mlx_array arr); + +int mlx_array_item_uint64(uint64_t* res, const mlx_array arr); + +int mlx_array_item_int8(int8_t* res, const mlx_array arr); + +int mlx_array_item_int16(int16_t* res, const mlx_array arr); + +int mlx_array_item_int32(int32_t* res, const mlx_array arr); + +int mlx_array_item_int64(int64_t* res, const mlx_array arr); + +int mlx_array_item_float32(float* res, const mlx_array arr); + +int mlx_array_item_float64(double* res, const mlx_array arr); + +int mlx_array_item_complex64(float _Complex* res, const mlx_array arr); + +#if defined(__aarch64__) || defined(_M_ARM64) +int mlx_array_item_float16(float16_t* res, const mlx_array arr); +#endif + +#if defined(__aarch64__) || defined(_M_ARM64) +int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr); +#endif + +const bool* mlx_array_data_bool(const mlx_array arr); + +const uint8_t* mlx_array_data_uint8(const mlx_array arr); + +const uint16_t* mlx_array_data_uint16(const mlx_array arr); + +const uint32_t* mlx_array_data_uint32(const mlx_array arr); + +const uint64_t* mlx_array_data_uint64(const mlx_array arr); + +const int8_t* mlx_array_data_int8(const mlx_array arr); + +const int16_t* mlx_array_data_int16(const mlx_array arr); + +const int32_t* mlx_array_data_int32(const mlx_array arr); + +const int64_t* mlx_array_data_int64(const mlx_array arr); + +const float* mlx_array_data_float32(const mlx_array arr); + +const double* mlx_array_data_float64(const mlx_array arr); + +const float _Complex* mlx_array_data_complex64(const mlx_array arr); + +#if defined(__aarch64__) || defined(_M_ARM64) +const float16_t* mlx_array_data_float16(const mlx_array arr); +#endif + +#if defined(__aarch64__) || defined(_M_ARM64) +const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr); +#endif + +int _mlx_array_is_available(bool* res, const mlx_array arr); + +int _mlx_array_wait(const mlx_array arr); + +int _mlx_array_is_contiguous(bool* res, const mlx_array arr); + +int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr); + +int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr); + +mlx_closure mlx_closure_new(void); + +int mlx_closure_free(mlx_closure cls); + +mlx_closure mlx_closure_new_func(int (*fun)(mlx_vector_array*, const mlx_vector_array)); + +mlx_closure mlx_closure_new_func_payload(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); + +int mlx_closure_set(mlx_closure* cls, const mlx_closure src); + +int mlx_closure_apply(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input); + +mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)); + +mlx_closure_kwargs mlx_closure_kwargs_new(void); + +int mlx_closure_kwargs_free(mlx_closure_kwargs cls); + +mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)); + +mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)); + +int mlx_closure_kwargs_set(mlx_closure_kwargs* cls, const mlx_closure_kwargs src); + +int mlx_closure_kwargs_apply(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1); + +mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void); + +int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls); + +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)); + +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); + +int mlx_closure_value_and_grad_set(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src); + +int mlx_closure_value_and_grad_apply(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input); + +mlx_closure_custom mlx_closure_custom_new(void); + +int mlx_closure_custom_free(mlx_closure_custom cls); + +mlx_closure_custom mlx_closure_custom_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)); + +mlx_closure_custom mlx_closure_custom_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)); + +int mlx_closure_custom_set(mlx_closure_custom* cls, const mlx_closure_custom src); + +int mlx_closure_custom_apply(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2); + +mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void); + +int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls); + +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)); + +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)); + +int mlx_closure_custom_jvp_set(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src); + +int mlx_closure_custom_jvp_apply(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num); + +mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void); + +int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls); + +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)); + +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)); + +int mlx_closure_custom_vmap_set(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src); + +int mlx_closure_custom_vmap_apply(mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num); + +int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless); + +int mlx_detail_compile(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num); + +int mlx_detail_compile_clear_cache(void); + +int mlx_detail_compile_erase(uintptr_t fun_id); + +int mlx_disable_compile(void); + +int mlx_enable_compile(void); + +int mlx_set_compile_mode(mlx_compile_mode mode); + +mlx_device mlx_device_new(void); + +mlx_device mlx_device_new_type(mlx_device_type type, int index); + +int mlx_device_free(mlx_device dev); + +int mlx_device_set(mlx_device* dev, const mlx_device src); + +int mlx_device_tostring(mlx_string* str, mlx_device dev); + +bool mlx_device_equal(mlx_device lhs, mlx_device rhs); + +int mlx_device_get_index(int* index, mlx_device dev); + +int mlx_device_get_type(mlx_device_type* type, mlx_device dev); + +int mlx_get_default_device(mlx_device* dev); + +int mlx_set_default_device(mlx_device dev); + +int mlx_distributed_all_gather(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S); + +int mlx_distributed_all_max(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); + +int mlx_distributed_all_min(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); + +int mlx_distributed_all_sum(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); + +int mlx_distributed_recv(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s); + +int mlx_distributed_recv_like(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s); + +int mlx_distributed_send(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s); + +int mlx_distributed_sum_scatter(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s); + +int mlx_distributed_group_rank(mlx_distributed_group group); + +int mlx_distributed_group_size(mlx_distributed_group group); + +mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key); + +bool mlx_distributed_is_available(void); + +mlx_distributed_group mlx_distributed_init(bool strict); + +void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*)); + +void _mlx_error(const char* file, const int line, const char* fmt, ...); + +int mlx_export_function(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless); + +int mlx_export_function_kwargs(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless); + +mlx_function_exporter mlx_function_exporter_new(const char* file, const mlx_closure fun, bool shapeless); + +int mlx_function_exporter_free(mlx_function_exporter xfunc); + +int mlx_function_exporter_apply(const mlx_function_exporter xfunc, const mlx_vector_array args); + +int mlx_function_exporter_apply_kwargs(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs); + +mlx_imported_function mlx_imported_function_new(const char* file); + +int mlx_imported_function_free(mlx_imported_function xfunc); + +int mlx_imported_function_apply(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args); + +int mlx_imported_function_apply_kwargs(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs); + +mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void); + +void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls); + +int mlx_fast_cuda_kernel_config_add_output_arg(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype); + +int mlx_fast_cuda_kernel_config_set_grid(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3); + +int mlx_fast_cuda_kernel_config_set_thread_group(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3); + +int mlx_fast_cuda_kernel_config_set_init_value(mlx_fast_cuda_kernel_config cls, float value); + +int mlx_fast_cuda_kernel_config_set_verbose(mlx_fast_cuda_kernel_config cls, bool verbose); + +int mlx_fast_cuda_kernel_config_add_template_arg_dtype(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype); + +int mlx_fast_cuda_kernel_config_add_template_arg_int(mlx_fast_cuda_kernel_config cls, const char* name, int value); + +int mlx_fast_cuda_kernel_config_add_template_arg_bool(mlx_fast_cuda_kernel_config cls, const char* name, bool value); + +mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory); + +void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls); + +int mlx_fast_cuda_kernel_apply(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream); + +int mlx_fast_layer_norm(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s); + +mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void); + +void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls); + +int mlx_fast_metal_kernel_config_add_output_arg(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype); + +int mlx_fast_metal_kernel_config_set_grid(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3); + +int mlx_fast_metal_kernel_config_set_thread_group(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3); + +int mlx_fast_metal_kernel_config_set_init_value(mlx_fast_metal_kernel_config cls, float value); + +int mlx_fast_metal_kernel_config_set_verbose(mlx_fast_metal_kernel_config cls, bool verbose); + +int mlx_fast_metal_kernel_config_add_template_arg_dtype(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype); + +int mlx_fast_metal_kernel_config_add_template_arg_int(mlx_fast_metal_kernel_config cls, const char* name, int value); + +int mlx_fast_metal_kernel_config_add_template_arg_bool(mlx_fast_metal_kernel_config cls, const char* name, bool value); + +mlx_fast_metal_kernel mlx_fast_metal_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs); + +void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls); + +int mlx_fast_metal_kernel_apply(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream); + +int mlx_fast_rms_norm(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s); + +int mlx_fast_rope(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s); + +int mlx_fast_rope_dynamic(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s); + +int mlx_fast_scaled_dot_product_attention(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s); + +int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); + +int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_fftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); + +int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_ifftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); + +int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s); + +int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_load_reader(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s); + +int mlx_load(mlx_array* res, const char* file, const mlx_stream s); + +int mlx_load_safetensors_reader(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s); + +int mlx_load_safetensors(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s); + +int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a); + +int mlx_save(const char* file, const mlx_array a); + +int mlx_save_safetensors_writer(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); + +int mlx_save_safetensors(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata); + +mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable); + +int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io); + +int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io); + +int mlx_io_reader_free(mlx_io_reader io); + +mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable); + +int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io); + +int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io); + +int mlx_io_writer_free(mlx_io_writer io); + +int mlx_linalg_cholesky(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); + +int mlx_linalg_cholesky_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); + +int mlx_linalg_cross(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s); + +int mlx_linalg_eig(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); + +int mlx_linalg_eigh(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s); + +int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_linalg_eigvalsh(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s); + +int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s); + +int mlx_linalg_lu_factor(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); + +int mlx_linalg_norm(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s); + +int mlx_linalg_norm_matrix(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s); + +int mlx_linalg_norm_l2(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s); + +int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_linalg_qr(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s); + +int mlx_linalg_solve(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_linalg_solve_triangular(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s); + +int mlx_linalg_svd(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s); + +int mlx_linalg_tri_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s); + +mlx_map_string_to_array mlx_map_string_to_array_new(void); + +int mlx_map_string_to_array_set(mlx_map_string_to_array* map, const mlx_map_string_to_array src); + +int mlx_map_string_to_array_free(mlx_map_string_to_array map); + +int mlx_map_string_to_array_insert(mlx_map_string_to_array map, const char* key, const mlx_array value); + +int mlx_map_string_to_array_get(mlx_array* value, const mlx_map_string_to_array map, const char* key); + +mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(mlx_map_string_to_array map); + +int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it); + +int mlx_map_string_to_array_iterator_next(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it); + +mlx_map_string_to_string mlx_map_string_to_string_new(void); + +int mlx_map_string_to_string_set(mlx_map_string_to_string* map, const mlx_map_string_to_string src); + +int mlx_map_string_to_string_free(mlx_map_string_to_string map); + +int mlx_map_string_to_string_insert(mlx_map_string_to_string map, const char* key, const char* value); + +int mlx_map_string_to_string_get(const char** value, const mlx_map_string_to_string map, const char* key); + +mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(mlx_map_string_to_string map); + +int mlx_map_string_to_string_iterator_free(mlx_map_string_to_string_iterator it); + +int mlx_map_string_to_string_iterator_next(const char** key, const char** value, mlx_map_string_to_string_iterator it); + +int mlx_clear_cache(void); + +int mlx_get_active_memory(size_t* res); + +int mlx_get_cache_memory(size_t* res); + +int mlx_get_memory_limit(size_t* res); + +int mlx_get_peak_memory(size_t* res); + +int mlx_reset_peak_memory(void); + +int mlx_set_cache_limit(size_t* res, size_t limit); + +int mlx_set_memory_limit(size_t* res, size_t limit); + +int mlx_set_wired_limit(size_t* res, size_t limit); + +mlx_metal_device_info_t mlx_metal_device_info(void); + +int mlx_metal_is_available(bool* res); + +int mlx_metal_start_capture(const char* path); + +int mlx_metal_stop_capture(void); + +int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_addmm(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s); + +int mlx_all_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_all_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_all(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_allclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s); + +int mlx_any_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_any_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_any(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_arange(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s); + +int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_arctan2(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_argmax_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_argmax(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_argmin_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_argmin(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_argpartition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s); + +int mlx_argpartition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s); + +int mlx_argsort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s); + +int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_array_equal(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s); + +int mlx_as_strided(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s); + +int mlx_astype(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s); + +int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s); + +int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s); + +int mlx_broadcast_to(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s); + +int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_clip(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s); + +int mlx_concatenate_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s); + +int mlx_concatenate(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s); + +int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_contiguous(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s); + +int mlx_conv1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s); + +int mlx_conv2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s); + +int mlx_conv3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s); + +int mlx_conv_general(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s); + +int mlx_conv_transpose1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s); + +int mlx_conv_transpose2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s); + +int mlx_conv_transpose3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s); + +int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_cummax(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); + +int mlx_cummin(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); + +int mlx_cumprod(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); + +int mlx_cumsum(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); + +int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies); + +int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s); + +int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); + +int mlx_diagonal(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s); + +int mlx_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_divmod(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_einsum(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s); + +int mlx_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_expand_dims_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_expand_dims(mlx_array* res, const mlx_array a, int axis, const mlx_stream s); + +int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_eye(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s); + +int mlx_flatten(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s); + +int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_floor_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_from_fp8(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s); + +int mlx_full(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s); + +int mlx_full_like(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s); + +int mlx_gather(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s); + +int mlx_gather_single(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s); + +int mlx_gather_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s); + +int mlx_gather_qmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s); + +int mlx_greater(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s); + +int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); + +int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_inner(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_isclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s); + +int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_kron(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_left_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_less(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_less_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_linspace(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s); + +int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_logaddexp(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_logcumsumexp(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s); + +int mlx_logical_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_logical_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_logsumexp_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_logsumexp_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_logsumexp(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_masked_scatter(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s); + +int mlx_matmul(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_max_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_max_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_max(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_maximum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_mean_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_mean_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_mean(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_median(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_meshgrid(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s); + +int mlx_min_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_min_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_min(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_minimum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_moveaxis(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s); + +int mlx_multiply(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_nan_to_num(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s); + +int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_not_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_number_of_elements(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s); + +int mlx_ones(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s); + +int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_outer(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_pad(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s); + +int mlx_pad_symmetric(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s); + +int mlx_partition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s); + +int mlx_partition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s); + +int mlx_power(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_prod_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_prod_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); + +int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); + +int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); + +int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s); + +int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_remainder(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_repeat_axis(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s); + +int mlx_repeat(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s); + +int mlx_reshape(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s); + +int mlx_right_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_roll_axis(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s); + +int mlx_roll_axes(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_roll(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s); + +int mlx_round(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s); + +int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_scatter(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_scatter_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); + +int mlx_scatter_add(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_scatter_add_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); + +int mlx_scatter_add_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s); + +int mlx_scatter_max(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_scatter_max_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); + +int mlx_scatter_min(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_scatter_min_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); + +int mlx_scatter_prod(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_scatter_prod_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s); + +int mlx_segmented_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s); + +int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_slice(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); + +int mlx_slice_dynamic(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s); + +int mlx_slice_update(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s); + +int mlx_slice_update_dynamic(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_softmax_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s); + +int mlx_softmax_axis(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s); + +int mlx_softmax(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s); + +int mlx_sort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s); + +int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_split(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s); + +int mlx_split_sections(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s); + +int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_squeeze_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_squeeze_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s); + +int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_stack_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s); + +int mlx_stack(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s); + +int mlx_std_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s); + +int mlx_std_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s); + +int mlx_std(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s); + +int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_subtract(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s); + +int mlx_sum_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s); + +int mlx_sum_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s); + +int mlx_sum(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s); + +int mlx_swapaxes(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s); + +int mlx_take_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s); + +int mlx_take(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s); + +int mlx_take_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s); + +int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_tensordot(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s); + +int mlx_tensordot_axis(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s); + +int mlx_tile(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s); + +int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s); + +int mlx_topk_axis(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s); + +int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s); + +int mlx_trace(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s); + +int mlx_transpose_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s); + +int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_tri(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s); + +int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s); + +int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s); + +int mlx_unflatten(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s); + +int mlx_var_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s); + +int mlx_var_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s); + +int mlx_var(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s); + +int mlx_view(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s); + +int mlx_where(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s); + +int mlx_zeros(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s); + +int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s); + +int mlx_random_bernoulli(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s); + +int mlx_random_bits(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s); + +int mlx_random_categorical_shape(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s); + +int mlx_random_categorical_num_samples(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s); + +int mlx_random_categorical(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s); + +int mlx_random_gumbel(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); + +int mlx_random_key(mlx_array* res, uint64_t seed); + +int mlx_random_laplace(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s); + +int mlx_random_multivariate_normal(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); + +int mlx_random_normal_broadcast(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s); + +int mlx_random_normal(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s); + +int mlx_random_permutation(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s); + +int mlx_random_permutation_arange(mlx_array* res, int x, const mlx_array key , const mlx_stream s); + +int mlx_random_randint(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); + +int mlx_random_seed(uint64_t seed); + +int mlx_random_split_num(mlx_array* res, const mlx_array key, int num, const mlx_stream s); + +int mlx_random_split(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s); + +int mlx_random_truncated_normal(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); + +int mlx_random_uniform(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s); + +mlx_stream mlx_stream_new(void); + +mlx_stream mlx_stream_new_device(mlx_device dev); + +int mlx_stream_set(mlx_stream* stream, const mlx_stream src); + +int mlx_stream_free(mlx_stream stream); + +int mlx_stream_tostring(mlx_string* str, mlx_stream stream); + +bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs); + +int mlx_stream_get_device(mlx_device* dev, mlx_stream stream); + +int mlx_stream_get_index(int* index, mlx_stream stream); + +int mlx_synchronize(mlx_stream stream); + +int mlx_get_default_stream(mlx_stream* stream, mlx_device dev); + +int mlx_set_default_stream(mlx_stream stream); + +mlx_stream mlx_default_cpu_stream_new(void); + +mlx_stream mlx_default_gpu_stream_new(void); + +mlx_string mlx_string_new(void); + +mlx_string mlx_string_new_data(const char* str); + +int mlx_string_set(mlx_string* str, const mlx_string src); + +const char* mlx_string_data(mlx_string str); + +int mlx_string_free(mlx_string str); + +int mlx_async_eval(const mlx_vector_array outputs); + +int mlx_checkpoint(mlx_closure* res, const mlx_closure fun); + +int mlx_custom_function(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap); + +int mlx_custom_vjp(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp); + +int mlx_eval(const mlx_vector_array outputs); + +int mlx_jvp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents); + +int mlx_value_and_grad(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num); + +int mlx_vjp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents); + +int mlx_detail_vmap_replace(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num); + +int mlx_detail_vmap_trace(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num); + +mlx_vector_array mlx_vector_array_new(void); + +int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src); + +int mlx_vector_array_free(mlx_vector_array vec); + +mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size); + +mlx_vector_array mlx_vector_array_new_value(const mlx_array val); + +int mlx_vector_array_set_data(mlx_vector_array* vec, const mlx_array* data, size_t size); + +int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val); + +int mlx_vector_array_append_data(mlx_vector_array vec, const mlx_array* data, size_t size); + +int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val); + +size_t mlx_vector_array_size(mlx_vector_array vec); + +int mlx_vector_array_get(mlx_array* res, const mlx_vector_array vec, size_t idx); + +mlx_vector_vector_array mlx_vector_vector_array_new(void); + +int mlx_vector_vector_array_set(mlx_vector_vector_array* vec, const mlx_vector_vector_array src); + +int mlx_vector_vector_array_free(mlx_vector_vector_array vec); + +mlx_vector_vector_array mlx_vector_vector_array_new_data(const mlx_vector_array* data, size_t size); + +mlx_vector_vector_array mlx_vector_vector_array_new_value(const mlx_vector_array val); + +int mlx_vector_vector_array_set_data(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size); + +int mlx_vector_vector_array_set_value(mlx_vector_vector_array* vec, const mlx_vector_array val); + +int mlx_vector_vector_array_append_data(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size); + +int mlx_vector_vector_array_append_value(mlx_vector_vector_array vec, const mlx_vector_array val); + +size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec); + +int mlx_vector_vector_array_get(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx); + +mlx_vector_int mlx_vector_int_new(void); + +int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src); + +int mlx_vector_int_free(mlx_vector_int vec); + +mlx_vector_int mlx_vector_int_new_data(int* data, size_t size); + +mlx_vector_int mlx_vector_int_new_value(int val); + +int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size); + +int mlx_vector_int_set_value(mlx_vector_int* vec, int val); + +int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size); + +int mlx_vector_int_append_value(mlx_vector_int vec, int val); + +size_t mlx_vector_int_size(mlx_vector_int vec); + +int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx); + +mlx_vector_string mlx_vector_string_new(void); + +int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src); + +int mlx_vector_string_free(mlx_vector_string vec); + +mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size); + +mlx_vector_string mlx_vector_string_new_value(const char* val); + +int mlx_vector_string_set_data(mlx_vector_string* vec, const char** data, size_t size); + +int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val); + +int mlx_vector_string_append_data(mlx_vector_string vec, const char** data, size_t size); + +int mlx_vector_string_append_value(mlx_vector_string vec, const char* val); + +size_t mlx_vector_string_size(mlx_vector_string vec); + +int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx); + +int mlx_version(mlx_string* str_); + +#endif // MLX_WRAPPERS_H diff --git a/x/imagegen/mlx/mlx_dynamic.c b/x/imagegen/mlx/mlx_dynamic.c new file mode 100644 index 000000000..aedef7a01 --- /dev/null +++ b/x/imagegen/mlx/mlx_dynamic.c @@ -0,0 +1,144 @@ +// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library +// This file provides runtime dynamic loading of libmlxc instead of link-time binding + +#include "mlx_dynamic.h" +#include +#include +#include + +#ifdef _WIN32 +#include +typedef HMODULE lib_handle_t; +#define LOAD_LIB(path) LoadLibraryA(path) +#define GET_SYMBOL(handle, name) GetProcAddress(handle, name) +#define CLOSE_LIB(handle) FreeLibrary(handle) +#define LIB_ERROR() "LoadLibrary failed" +#else +#include +typedef void* lib_handle_t; +#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL) +#define GET_SYMBOL(handle, name) dlsym(handle, name) +#define CLOSE_LIB(handle) dlclose(handle) +#define LIB_ERROR() dlerror() +#ifdef __APPLE__ +#include +#include +#endif +#endif + +static lib_handle_t mlx_handle = NULL; +static int mlx_initialized = 0; +static char mlx_error_buffer[512] = {0}; + +#ifdef __APPLE__ +// Get path to library in same directory as executable +static char* get_exe_relative_path(const char* libname) { + static char path[1024]; + uint32_t size = sizeof(path); + if (_NSGetExecutablePath(path, &size) != 0) { + return NULL; + } + // Get directory of executable + char* dir = dirname(path); + static char fullpath[1024]; + snprintf(fullpath, sizeof(fullpath), "%s/%s", dir, libname); + return fullpath; +} +#endif + +// Try to load library from a specific path +static int try_load_lib(const char* path) { + if (!path) return 0; + mlx_handle = LOAD_LIB(path); + return mlx_handle != NULL; +} + +// Initialize MLX dynamic library +// Returns 0 on success, -1 on failure +// On failure, call mlx_dynamic_error() to get error message +int mlx_dynamic_init(void) { + if (mlx_initialized) { + return 0; // Already initialized + } + + const char* lib_path = NULL; + const char* tried_paths[8] = {0}; + int num_tried = 0; + +#ifdef _WIN32 + // Windows: try same directory as executable + lib_path = "libmlxc.dll"; + tried_paths[num_tried++] = lib_path; + if (try_load_lib(lib_path)) goto success; +#elif defined(__APPLE__) + // macOS: try executable directory first + lib_path = get_exe_relative_path("libmlxc.dylib"); + if (lib_path) { + tried_paths[num_tried++] = lib_path; + if (try_load_lib(lib_path)) goto success; + } + // Try build directory (for tests run from repo root) + lib_path = "./build/lib/ollama/libmlxc.dylib"; + tried_paths[num_tried++] = lib_path; + if (try_load_lib(lib_path)) goto success; + // Fallback to system paths + lib_path = "libmlxc.dylib"; + tried_paths[num_tried++] = lib_path; + if (try_load_lib(lib_path)) goto success; +#else + // Linux: try build directory first (for tests) + lib_path = "./build/lib/ollama/libmlxc.so"; + tried_paths[num_tried++] = lib_path; + if (try_load_lib(lib_path)) goto success; + // Fallback to system paths + lib_path = "libmlxc.so"; + tried_paths[num_tried++] = lib_path; + if (try_load_lib(lib_path)) goto success; +#endif + + // Failed to load library - build error message with all tried paths + { + const char* err = LIB_ERROR(); + int offset = snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), + "MLX: Failed to load libmlxc library. Tried: "); + for (int i = 0; i < num_tried && offset < (int)sizeof(mlx_error_buffer) - 50; i++) { + offset += snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset, + "%s%s", i > 0 ? ", " : "", tried_paths[i]); + } + if (err) { + snprintf(mlx_error_buffer + offset, sizeof(mlx_error_buffer) - offset, + ". Last error: %s", err); + } + } + return -1; + +success: + mlx_initialized = 1; + snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), + "MLX: Successfully loaded %s", lib_path ? lib_path : "library"); + return 0; +} + +// Get the last error message +const char* mlx_dynamic_error(void) { + return mlx_error_buffer; +} + +// Check if MLX is initialized +int mlx_dynamic_is_initialized(void) { + return mlx_initialized; +} + +// Get the library handle (for use by generated wrappers) +void* mlx_get_handle(void) { + return mlx_handle; +} + +// Cleanup (optional, called at program exit) +void mlx_dynamic_cleanup(void) { + if (mlx_handle != NULL) { + CLOSE_LIB(mlx_handle); + mlx_handle = NULL; + mlx_initialized = 0; + } +} diff --git a/x/imagegen/mlx/mlx_dynamic.h b/x/imagegen/mlx/mlx_dynamic.h new file mode 100644 index 000000000..9ca1473f9 --- /dev/null +++ b/x/imagegen/mlx/mlx_dynamic.h @@ -0,0 +1,29 @@ +// mlx_dynamic.h - Dynamic loading interface for MLX-C library +#ifndef MLX_DYNAMIC_H +#define MLX_DYNAMIC_H + +#ifdef __cplusplus +extern "C" { +#endif + +// Initialize the MLX dynamic library +// Returns 0 on success, -1 on failure +int mlx_dynamic_init(void); + +// Get the last error message from dynamic loading +const char* mlx_dynamic_error(void); + +// Check if MLX is initialized +int mlx_dynamic_is_initialized(void); + +// Get the library handle (for use by generated wrappers) +void* mlx_get_handle(void); + +// Cleanup resources (optional, for clean shutdown) +void mlx_dynamic_cleanup(void); + +#ifdef __cplusplus +} +#endif + +#endif // MLX_DYNAMIC_H diff --git a/x/imagegen/mlx/mlx_test.go b/x/imagegen/mlx/mlx_test.go index db8fe394f..37b3ac63b 100644 --- a/x/imagegen/mlx/mlx_test.go +++ b/x/imagegen/mlx/mlx_test.go @@ -4,9 +4,30 @@ package mlx import ( "fmt" + "os" + "path/filepath" + "runtime" "testing" ) +// TestMain initializes MLX before running tests. +// If MLX libraries are not available, tests are skipped. +func TestMain(m *testing.M) { + // Change to repo root so ./build/lib/ollama/ path works + _, thisFile, _, _ := runtime.Caller(0) + repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + if err := os.Chdir(repoRoot); err != nil { + fmt.Printf("Failed to change to repo root: %v\n", err) + os.Exit(1) + } + + if err := InitMLX(); err != nil { + fmt.Printf("Skipping MLX tests: %v\n", err) + os.Exit(0) + } + os.Exit(m.Run()) +} + // TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive. func TestBasicCleanup(t *testing.T) { weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2}) diff --git a/x/imagegen/models/qwen_image/pipeline_test.go b/x/imagegen/models/qwen_image/pipeline_test.go index 5625427f5..4a0ad7135 100644 --- a/x/imagegen/models/qwen_image/pipeline_test.go +++ b/x/imagegen/models/qwen_image/pipeline_test.go @@ -3,12 +3,33 @@ package qwen_image import ( + "fmt" "os" + "path/filepath" + "runtime" "testing" "github.com/ollama/ollama/x/imagegen/mlx" ) +// TestMain initializes MLX before running tests. +// If MLX libraries are not available, tests are skipped. +func TestMain(m *testing.M) { + // Change to repo root so ./build/lib/ollama/ path works + _, thisFile, _, _ := runtime.Caller(0) + repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..") + if err := os.Chdir(repoRoot); err != nil { + fmt.Printf("Failed to change to repo root: %v\n", err) + os.Exit(1) + } + + if err := mlx.InitMLX(); err != nil { + fmt.Printf("Skipping qwen_image tests: %v\n", err) + os.Exit(0) + } + os.Exit(m.Run()) +} + // TestPipelineOutput runs the full pipeline (integration test). // Skips if model weights not found. Requires ~50GB VRAM. func TestPipelineOutput(t *testing.T) { diff --git a/x/imagegen/models/qwen_image/qwen_image.go b/x/imagegen/models/qwen_image/qwen_image.go index c6a69d38f..32f0cac54 100644 --- a/x/imagegen/models/qwen_image/qwen_image.go +++ b/x/imagegen/models/qwen_image/qwen_image.go @@ -9,7 +9,6 @@ import ( "path/filepath" "time" - "github.com/ollama/ollama/x/imagegen" "github.com/ollama/ollama/x/imagegen/cache" "github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/tokenizer" diff --git a/x/imagegen/models/qwen_image_edit/rope_test.go b/x/imagegen/models/qwen_image_edit/rope_test.go index 7da4eaa3c..200940fbe 100644 --- a/x/imagegen/models/qwen_image_edit/rope_test.go +++ b/x/imagegen/models/qwen_image_edit/rope_test.go @@ -3,13 +3,35 @@ package qwen_image_edit import ( + "fmt" "math" + "os" + "path/filepath" + "runtime" "testing" "github.com/ollama/ollama/x/imagegen/mlx" "github.com/ollama/ollama/x/imagegen/models/qwen_image" ) +// TestMain initializes MLX before running tests. +// If MLX libraries are not available, tests are skipped. +func TestMain(m *testing.M) { + // Change to repo root so ./build/lib/ollama/ path works + _, thisFile, _, _ := runtime.Caller(0) + repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..") + if err := os.Chdir(repoRoot); err != nil { + fmt.Printf("Failed to change to repo root: %v\n", err) + os.Exit(1) + } + + if err := mlx.InitMLX(); err != nil { + fmt.Printf("Skipping qwen_image_edit tests: %v\n", err) + os.Exit(0) + } + os.Exit(m.Run()) +} + // TestComputeAxisFreqs verifies frequency computation matches Python reference func TestComputeAxisFreqs(t *testing.T) { theta := float64(10000) diff --git a/x/imagegen/nn/nn_test.go b/x/imagegen/nn/nn_test.go index 2f8c04762..00e69ccb0 100644 --- a/x/imagegen/nn/nn_test.go +++ b/x/imagegen/nn/nn_test.go @@ -3,12 +3,34 @@ package nn import ( + "fmt" "math" + "os" + "path/filepath" + "runtime" "testing" "github.com/ollama/ollama/x/imagegen/mlx" ) +// TestMain initializes MLX before running tests. +// If MLX libraries are not available, tests are skipped. +func TestMain(m *testing.M) { + // Change to repo root so ./build/lib/ollama/ path works + _, thisFile, _, _ := runtime.Caller(0) + repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..") + if err := os.Chdir(repoRoot); err != nil { + fmt.Printf("Failed to change to repo root: %v\n", err) + os.Exit(1) + } + + if err := mlx.InitMLX(); err != nil { + fmt.Printf("Skipping nn tests: %v\n", err) + os.Exit(0) + } + os.Exit(m.Run()) +} + // TestLinearNoBias verifies Linear without bias computes x @ w.T correctly. func TestLinearNoBias(t *testing.T) { // Weight: [out=2, in=3] -> transposed at forward time diff --git a/x/imagegen/runner/runner.go b/x/imagegen/runner/runner.go index 6354ce234..ede11e765 100644 --- a/x/imagegen/runner/runner.go +++ b/x/imagegen/runner/runner.go @@ -62,6 +62,12 @@ func Execute(args []string) error { return fmt.Errorf("--port is required") } + err := mlx.InitMLX() + if err != nil { + slog.Error("unable to initialize MLX", "error", err) + return err + } + slog.Info("MLX library initialized") slog.Info("starting image runner", "model", *modelName, "port", *port) // Check memory requirements before loading diff --git a/x/imagegen/server.go b/x/imagegen/server.go index d6e1684d5..7c55ad77a 100644 --- a/x/imagegen/server.go +++ b/x/imagegen/server.go @@ -62,7 +62,7 @@ func NewServer(modelName string) (*Server, error) { port = rand.Intn(65535-49152) + 49152 } - // Get the ollama-mlx executable path (in same directory as current executable) + // Get the current executable path (we use the same binary with runner subcommand) exe, err := os.Executable() if err != nil { return nil, fmt.Errorf("unable to lookup executable path: %w", err) @@ -70,10 +70,9 @@ func NewServer(modelName string) (*Server, error) { if eval, err := filepath.EvalSymlinks(exe); err == nil { exe = eval } - mlxExe := filepath.Join(filepath.Dir(exe), "ollama-mlx") - // Spawn subprocess: ollama-mlx runner --image-engine --model --port - cmd := exec.Command(mlxExe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port)) + // Spawn subprocess: ollama runner --image-engine --model --port + cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port)) cmd.Env = os.Environ() // On Linux, set LD_LIBRARY_PATH to include MLX library directories @@ -135,7 +134,7 @@ func NewServer(modelName string) (*Server, error) { } }() - slog.Info("starting ollama-mlx image runner subprocess", "exe", mlxExe, "model", modelName, "port", port) + slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port) if err := cmd.Start(); err != nil { return nil, fmt.Errorf("failed to start image runner: %w", err) } diff --git a/x/ml/backend/mlx/CMakeLists.txt b/x/ml/backend/mlx/CMakeLists.txt index e71a6567a..b62cbf2eb 100644 --- a/x/ml/backend/mlx/CMakeLists.txt +++ b/x/ml/backend/mlx/CMakeLists.txt @@ -1,5 +1,9 @@ include(FetchContent) +# Read MLX version from top-level file (shared with Dockerfile) +file(READ "${CMAKE_SOURCE_DIR}/MLX_VERSION" MLX_C_GIT_TAG) +string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG) + set(MLX_C_BUILD_EXAMPLES OFF) set(MLX_BUILD_GGUF OFF) @@ -50,7 +54,7 @@ endif() FetchContent_Declare( mlx-c GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" - GIT_TAG v0.4.1) + GIT_TAG ${MLX_C_GIT_TAG}) FetchContent_MakeAvailable(mlx-c) set_target_output_directory(mlx) diff --git a/x/ml/backend/mlx/mlx_dynamic.c b/x/ml/backend/mlx/mlx_dynamic.c new file mode 100644 index 000000000..0038355ae --- /dev/null +++ b/x/ml/backend/mlx/mlx_dynamic.c @@ -0,0 +1,92 @@ +// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library +// This file provides runtime dynamic loading of libmlxc instead of link-time binding + +#include "mlx_dynamic.h" +#include +#include +#include + +#ifdef _WIN32 +#include +typedef HMODULE lib_handle_t; +#define LOAD_LIB(path) LoadLibraryA(path) +#define GET_SYMBOL(handle, name) GetProcAddress(handle, name) +#define CLOSE_LIB(handle) FreeLibrary(handle) +#define LIB_ERROR() "LoadLibrary failed" +static const char* LIB_NAMES[] = {"libmlxc.dll", NULL}; +#else +#include +typedef void* lib_handle_t; +#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL) +#define GET_SYMBOL(handle, name) dlsym(handle, name) +#define CLOSE_LIB(handle) dlclose(handle) +#define LIB_ERROR() dlerror() +#ifdef __APPLE__ +static const char* LIB_NAMES[] = { + "libmlxc.dylib", + "@loader_path/../build/lib/ollama/libmlxc.dylib", + "@executable_path/../build/lib/ollama/libmlxc.dylib", + "build/lib/ollama/libmlxc.dylib", + "../build/lib/ollama/libmlxc.dylib", + NULL +}; +#else +static const char* LIB_NAMES[] = { + "libmlxc.so", + "$ORIGIN/../build/lib/ollama/libmlxc.so", + "build/lib/ollama/libmlxc.so", + "../build/lib/ollama/libmlxc.so", + NULL +}; +#endif +#endif + +static lib_handle_t mlx_handle = NULL; +static int mlx_initialized = 0; +static char mlx_error_buffer[512] = {0}; + +// Initialize MLX dynamic library +// Returns 0 on success, -1 on failure +// On failure, call mlx_dynamic_error() to get error message +int mlx_dynamic_init(void) { + if (mlx_initialized) { + return 0; // Already initialized + } + + // Try each possible library path + for (int i = 0; LIB_NAMES[i] != NULL; i++) { + mlx_handle = LOAD_LIB(LIB_NAMES[i]); + if (mlx_handle != NULL) { + mlx_initialized = 1; + snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), + "MLX: Successfully loaded %s", LIB_NAMES[i]); + return 0; + } + } + + // Failed to load library + const char* err = LIB_ERROR(); + snprintf(mlx_error_buffer, sizeof(mlx_error_buffer), + "MLX: Failed to load libmlxc library. %s", + err ? err : "Unknown error"); + return -1; +} + +// Get the last error message +const char* mlx_dynamic_error(void) { + return mlx_error_buffer; +} + +// Check if MLX is initialized +int mlx_dynamic_is_initialized(void) { + return mlx_initialized; +} + +// Cleanup (optional, called at program exit) +void mlx_dynamic_cleanup(void) { + if (mlx_handle != NULL) { + CLOSE_LIB(mlx_handle); + mlx_handle = NULL; + mlx_initialized = 0; + } +} diff --git a/x/ml/backend/mlx/mlx_dynamic.h b/x/ml/backend/mlx/mlx_dynamic.h new file mode 100644 index 000000000..2ae162a9a --- /dev/null +++ b/x/ml/backend/mlx/mlx_dynamic.h @@ -0,0 +1,26 @@ +// mlx_dynamic.h - Dynamic loading interface for MLX-C library +#ifndef MLX_DYNAMIC_H +#define MLX_DYNAMIC_H + +#ifdef __cplusplus +extern "C" { +#endif + +// Initialize the MLX dynamic library +// Returns 0 on success, -1 on failure +int mlx_dynamic_init(void); + +// Get the last error message from dynamic loading +const char* mlx_dynamic_error(void); + +// Check if MLX is initialized +int mlx_dynamic_is_initialized(void); + +// Cleanup resources (optional, for clean shutdown) +void mlx_dynamic_cleanup(void); + +#ifdef __cplusplus +} +#endif + +#endif // MLX_DYNAMIC_H