Files
ollama/x/imagegen/mlx/mlx.c
Daniel Hiltgen 12719b6e87 MLX - dynamic loading of mlx-c (#13735)
* MLX - dynamic loading of mlx-c

Create a wrapper layer to indirect the dependency on mlx-c so
the main ollama binary does not have a load-time dependency on mlx-c, mlx, and (on Linux) CUDA. Lazy-load the library via dlopen
so we can adjust the path to ensure the dependencies are found
and fail gracefully if not present.

* review comments

* fix broken tests
2026-01-16 16:34:22 -08:00

5787 lines
269 KiB
C

// AUTO-GENERATED by generate_wrappers.go - DO NOT EDIT
// This file contains the function pointer definitions and initialization
// All function pointers are in a single compilation unit to avoid duplication
#include "mlx/c/mlx.h"
#include "mlx_dynamic.h"
#include <stdio.h>
#include <dlfcn.h>
// Function pointer definitions
// --- dtype / mlx_array core ------------------------------------------------
// One global function pointer per wrapped entry point, named `<symbol>_ptr`.
// Every pointer is defined exactly once here and starts out NULL.
// NOTE(review): presumably each one is resolved with dlsym() by loader code
// elsewhere in this file (not visible in this chunk); calling a pointer that
// is still NULL is undefined behavior, so callers must load first.
size_t (*mlx_dtype_size_ptr)(mlx_dtype dtype) = NULL;
int (*mlx_array_tostring_ptr)(mlx_string* str, const mlx_array arr) = NULL;
// Constructors: each returns an mlx_array by value.
mlx_array (*mlx_array_new_ptr)(void) = NULL;
int (*mlx_array_free_ptr)(mlx_array arr) = NULL;
mlx_array (*mlx_array_new_bool_ptr)(bool val) = NULL;
mlx_array (*mlx_array_new_int_ptr)(int val) = NULL;
// The float32/float and float64/double pairs have identical signatures.
mlx_array (*mlx_array_new_float32_ptr)(float val) = NULL;
mlx_array (*mlx_array_new_float_ptr)(float val) = NULL;
mlx_array (*mlx_array_new_float64_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_double_ptr)(double val) = NULL;
mlx_array (*mlx_array_new_complex_ptr)(float real_val, float imag_val) = NULL;
mlx_array (*mlx_array_new_data_ptr)(const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
// Setters: take the destination as `mlx_array*` and return int.
int (*mlx_array_set_ptr)(mlx_array* arr, const mlx_array src) = NULL;
int (*mlx_array_set_bool_ptr)(mlx_array* arr, bool val) = NULL;
int (*mlx_array_set_int_ptr)(mlx_array* arr, int val) = NULL;
int (*mlx_array_set_float32_ptr)(mlx_array* arr, float val) = NULL;
int (*mlx_array_set_float_ptr)(mlx_array* arr, float val) = NULL;
int (*mlx_array_set_float64_ptr)(mlx_array* arr, double val) = NULL;
int (*mlx_array_set_double_ptr)(mlx_array* arr, double val) = NULL;
int (*mlx_array_set_complex_ptr)(mlx_array* arr, float real_val, float imag_val) = NULL;
int (*mlx_array_set_data_ptr)(mlx_array* arr, const void* data, const int* shape, int dim, mlx_dtype dtype) = NULL;
// Metadata queries: results are returned directly rather than via an out-parameter.
size_t (*mlx_array_itemsize_ptr)(const mlx_array arr) = NULL;
size_t (*mlx_array_size_ptr)(const mlx_array arr) = NULL;
size_t (*mlx_array_nbytes_ptr)(const mlx_array arr) = NULL;
size_t (*mlx_array_ndim_ptr)(const mlx_array arr) = NULL;
const int* (*mlx_array_shape_ptr)(const mlx_array arr) = NULL;
const size_t* (*mlx_array_strides_ptr)(const mlx_array arr) = NULL;
int (*mlx_array_dim_ptr)(const mlx_array arr, int dim) = NULL;
mlx_dtype (*mlx_array_dtype_ptr)(const mlx_array arr) = NULL;
int (*mlx_array_eval_ptr)(mlx_array arr) = NULL;
// Scalar readers: the value is written through `res`; the int return is
// presumably a status code (TODO confirm against the mlx-c headers).
int (*mlx_array_item_bool_ptr)(bool* res, const mlx_array arr) = NULL;
int (*mlx_array_item_uint8_ptr)(uint8_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_uint16_ptr)(uint16_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_uint32_ptr)(uint32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_uint64_ptr)(uint64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int8_ptr)(int8_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int16_ptr)(int16_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int32_ptr)(int32_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_int64_ptr)(int64_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float32_ptr)(float* res, const mlx_array arr) = NULL;
int (*mlx_array_item_float64_ptr)(double* res, const mlx_array arr) = NULL;
int (*mlx_array_item_complex64_ptr)(float _Complex* res, const mlx_array arr) = NULL;
// Half-precision scalar readers. These are only compiled when targeting
// 64-bit ARM (__aarch64__ for GCC/Clang, _M_ARM64 for MSVC); both share one guard.
#if defined(__aarch64__) || defined(_M_ARM64)
int (*mlx_array_item_float16_ptr)(float16_t* res, const mlx_array arr) = NULL;
int (*mlx_array_item_bfloat16_ptr)(bfloat16_t* res, const mlx_array arr) = NULL;
#endif  // __aarch64__ || _M_ARM64
// --- mlx_array typed data accessors ------------------------------------------
// One pointer per element type; each takes an mlx_array and returns a
// pointer-to-const of that element type. All start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this
// chunk; the lifetime of the returned buffer is not visible here.
const bool* (*mlx_array_data_bool_ptr)(const mlx_array arr) = NULL;
const uint8_t* (*mlx_array_data_uint8_ptr)(const mlx_array arr) = NULL;
const uint16_t* (*mlx_array_data_uint16_ptr)(const mlx_array arr) = NULL;
const uint32_t* (*mlx_array_data_uint32_ptr)(const mlx_array arr) = NULL;
const uint64_t* (*mlx_array_data_uint64_ptr)(const mlx_array arr) = NULL;
const int8_t* (*mlx_array_data_int8_ptr)(const mlx_array arr) = NULL;
const int16_t* (*mlx_array_data_int16_ptr)(const mlx_array arr) = NULL;
const int32_t* (*mlx_array_data_int32_ptr)(const mlx_array arr) = NULL;
const int64_t* (*mlx_array_data_int64_ptr)(const mlx_array arr) = NULL;
const float* (*mlx_array_data_float32_ptr)(const mlx_array arr) = NULL;
const double* (*mlx_array_data_float64_ptr)(const mlx_array arr) = NULL;
const float _Complex* (*mlx_array_data_complex64_ptr)(const mlx_array arr) = NULL;
// Half-precision data accessors. These are only compiled when targeting
// 64-bit ARM (__aarch64__ for GCC/Clang, _M_ARM64 for MSVC); both share one guard.
#if defined(__aarch64__) || defined(_M_ARM64)
const float16_t* (*mlx_array_data_float16_ptr)(const mlx_array arr) = NULL;
const bfloat16_t* (*mlx_array_data_bfloat16_ptr)(const mlx_array arr) = NULL;
#endif  // __aarch64__ || _M_ARM64
// --- underscore-prefixed mlx_array helpers ---------------------------------
// Pointers for the `_mlx_array_*` entry points; all start out NULL.
// The boolean queries write their answer through `res` and return int.
// NOTE(review): the leading underscore presumably mirrors the upstream symbol
// names; file-scope identifiers starting with `_` are reserved by the C
// standard, so do not add new names in this style.
int (*_mlx_array_is_available_ptr)(bool* res, const mlx_array arr) = NULL;
int (*_mlx_array_wait_ptr)(const mlx_array arr) = NULL;
int (*_mlx_array_is_contiguous_ptr)(bool* res, const mlx_array arr) = NULL;
int (*_mlx_array_is_row_contiguous_ptr)(bool* res, const mlx_array arr) = NULL;
int (*_mlx_array_is_col_contiguous_ptr)(bool* res, const mlx_array arr) = NULL;
// --- closures --------------------------------------------------------------
// Function pointers for six closure handle types. Each family has the same
// six members: new, free, new_func, new_func_payload, set and apply.
// The `*_new_func_payload` variants additionally take an opaque `payload`
// and a `dtor` callback; the callback signature gains a trailing `void*`.
// All pointers start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.

// mlx_closure: callback maps a vector of arrays to a vector of arrays.
mlx_closure (*mlx_closure_new_ptr)(void) = NULL;
int (*mlx_closure_free_ptr)(mlx_closure cls) = NULL;
mlx_closure (*mlx_closure_new_func_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array)) = NULL;
mlx_closure (*mlx_closure_new_func_payload_ptr)(int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_closure_set_ptr)(mlx_closure* cls, const mlx_closure src) = NULL;
int (*mlx_closure_apply_ptr)(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) = NULL;
// Extra constructor that only this family has: single-array in, single-array out.
mlx_closure (*mlx_closure_new_unary_ptr)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
// mlx_closure_kwargs: callback also receives a string-to-array map.
mlx_closure_kwargs (*mlx_closure_kwargs_new_ptr)(void) = NULL;
int (*mlx_closure_kwargs_free_ptr)(mlx_closure_kwargs cls) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) = NULL;
mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*), void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_closure_kwargs_set_ptr)(mlx_closure_kwargs* cls, const mlx_closure_kwargs src) = NULL;
int (*mlx_closure_kwargs_apply_ptr)(mlx_vector_array* res, mlx_closure_kwargs cls, const mlx_vector_array input_0, const mlx_map_string_to_array input_1) = NULL;
// mlx_closure_value_and_grad: callback fills two output vectors.
mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_ptr)(void) = NULL;
int (*mlx_closure_value_and_grad_free_ptr)(mlx_closure_value_and_grad cls) = NULL;
mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_ptr)(int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) = NULL;
mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_closure_value_and_grad_set_ptr)(mlx_closure_value_and_grad* cls, const mlx_closure_value_and_grad src) = NULL;
int (*mlx_closure_value_and_grad_apply_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, mlx_closure_value_and_grad cls, const mlx_vector_array input) = NULL;
// mlx_closure_custom: callback takes three input vectors.
mlx_closure_custom (*mlx_closure_custom_new_ptr)(void) = NULL;
int (*mlx_closure_custom_free_ptr)(mlx_closure_custom cls) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array)) = NULL;
mlx_closure_custom (*mlx_closure_custom_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void*), void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_closure_custom_set_ptr)(mlx_closure_custom* cls, const mlx_closure_custom src) = NULL;
int (*mlx_closure_custom_apply_ptr)(mlx_vector_array* res, mlx_closure_custom cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const mlx_vector_array input_2) = NULL;
// mlx_closure_custom_jvp: callback takes two input vectors plus an int array with its length.
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_ptr)(void) = NULL;
int (*mlx_closure_custom_jvp_free_ptr)(mlx_closure_custom_jvp cls) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num)) = NULL;
mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, const mlx_vector_array, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_closure_custom_jvp_set_ptr)(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) = NULL;
int (*mlx_closure_custom_jvp_apply_ptr)(mlx_vector_array* res, mlx_closure_custom_jvp cls, const mlx_vector_array input_0, const mlx_vector_array input_1, const int* input_2, size_t input_2_num) = NULL;
// mlx_closure_custom_vmap: callback fills an array vector and an int vector.
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_ptr)(void) = NULL;
int (*mlx_closure_custom_vmap_free_ptr)(mlx_closure_custom_vmap cls) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num)) = NULL;
mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_ptr)(int (*fun)( mlx_vector_array*, mlx_vector_int*, const mlx_vector_array, const int*, size_t _num, void*), void* payload, void (*dtor)(void*)) = NULL;
int (*mlx_closure_custom_vmap_set_ptr)(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) = NULL;
int (*mlx_closure_custom_vmap_apply_ptr)(mlx_vector_array* res_0, mlx_vector_int* res_1, mlx_closure_custom_vmap cls, const mlx_vector_array input_0, const int* input_1, size_t input_1_num) = NULL;
// --- compile control -------------------------------------------------------
// Pointers for the compile entry points; all return int and start out NULL.
// `mlx_compile` and `mlx_detail_compile` write the resulting closure through `res`.
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.
int (*mlx_compile_ptr)(mlx_closure* res, const mlx_closure fun, bool shapeless) = NULL;
int (*mlx_detail_compile_ptr)(mlx_closure* res, const mlx_closure fun, uintptr_t fun_id, bool shapeless, const uint64_t* constants, size_t constants_num) = NULL;
int (*mlx_detail_compile_clear_cache_ptr)(void) = NULL;
int (*mlx_detail_compile_erase_ptr)(uintptr_t fun_id) = NULL;
int (*mlx_disable_compile_ptr)(void) = NULL;
int (*mlx_enable_compile_ptr)(void) = NULL;
int (*mlx_set_compile_mode_ptr)(mlx_compile_mode mode) = NULL;
// --- devices ---------------------------------------------------------------
// Pointers for mlx_device construction, comparison, inspection and the
// default-device getter/setter. All start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.
mlx_device (*mlx_device_new_ptr)(void) = NULL;
mlx_device (*mlx_device_new_type_ptr)(mlx_device_type type, int index) = NULL;
int (*mlx_device_free_ptr)(mlx_device dev) = NULL;
int (*mlx_device_set_ptr)(mlx_device* dev, const mlx_device src) = NULL;
int (*mlx_device_tostring_ptr)(mlx_string* str, mlx_device dev) = NULL;
// Unlike its neighbours, equality returns bool directly rather than an int.
bool (*mlx_device_equal_ptr)(mlx_device lhs, mlx_device rhs) = NULL;
int (*mlx_device_get_index_ptr)(int* index, mlx_device dev) = NULL;
int (*mlx_device_get_type_ptr)(mlx_device_type* type, mlx_device dev) = NULL;
int (*mlx_get_default_device_ptr)(mlx_device* dev) = NULL;
int (*mlx_set_default_device_ptr)(mlx_device dev) = NULL;
// --- distributed -----------------------------------------------------------
// Pointers for the distributed collectives, send/recv and group management.
// The array operations write their result through `res` and take the group
// and stream as their last two parameters. All start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.

// The stream parameter is spelled `S` here but `s` on every sibling; parameter
// names in a function-pointer type have no effect on the type.
int (*mlx_distributed_all_gather_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream S) = NULL;
int (*mlx_distributed_all_max_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_all_min_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_all_sum_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_recv_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, int src, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_recv_like_ptr)(mlx_array* res, const mlx_array x, int src, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_send_ptr)(mlx_array* res, const mlx_array x, int dst, const mlx_distributed_group group , const mlx_stream s) = NULL;
int (*mlx_distributed_sum_scatter_ptr)(mlx_array* res, const mlx_array x, const mlx_distributed_group group , const mlx_stream s) = NULL;
// Group queries return their value directly (no out-parameter).
int (*mlx_distributed_group_rank_ptr)(mlx_distributed_group group) = NULL;
int (*mlx_distributed_group_size_ptr)(mlx_distributed_group group) = NULL;
mlx_distributed_group (*mlx_distributed_group_split_ptr)(mlx_distributed_group group, int color, int key) = NULL;
bool (*mlx_distributed_is_available_ptr)(void) = NULL;
mlx_distributed_group (*mlx_distributed_init_ptr)(bool strict) = NULL;
// --- error handling --------------------------------------------------------
// Both pointers return void and start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.
// Installs `handler` together with an opaque `data` pointer and a `dtor` callback.
void (*mlx_set_error_handler_ptr)(mlx_error_handler_func handler, void* data, void (*dtor)(void*)) = NULL;
// Variadic: takes a source location plus `fmt` and trailing arguments
// (presumably printf-style — TODO confirm against the mlx-c headers).
void (*_mlx_error_ptr)(const char* file, const int line, const char* fmt, ...) = NULL;
// --- function export / import ------------------------------------------------
// Pointers for exporting closures to a file path, the mlx_function_exporter
// handle, and the mlx_imported_function handle. All start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.
int (*mlx_export_function_ptr)(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) = NULL;
int (*mlx_export_function_kwargs_ptr)(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) = NULL;
mlx_function_exporter (*mlx_function_exporter_new_ptr)(const char* file, const mlx_closure fun, bool shapeless) = NULL;
int (*mlx_function_exporter_free_ptr)(mlx_function_exporter xfunc) = NULL;
int (*mlx_function_exporter_apply_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args) = NULL;
int (*mlx_function_exporter_apply_kwargs_ptr)(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) = NULL;
mlx_imported_function (*mlx_imported_function_new_ptr)(const char* file) = NULL;
int (*mlx_imported_function_free_ptr)(mlx_imported_function xfunc) = NULL;
// Applying an imported function writes its outputs through `res`.
int (*mlx_imported_function_apply_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args) = NULL;
int (*mlx_imported_function_apply_kwargs_ptr)(mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) = NULL;
// --- mlx_fast --------------------------------------------------------------
// Pointers for the `mlx_fast_*` entry points: custom CUDA and Metal kernels
// (config object + kernel object) and the fused layer_norm / rms_norm / rope /
// scaled-dot-product-attention operations. All start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this
// chunk; whether the CUDA or Metal symbols exist in a given library build is
// not visible here.

// CUDA kernel config. Note that the `*_free` pointers in this family return
// void, whereas the free pointers of the other handle types return int.
mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_ptr)(void) = NULL;
void (*mlx_fast_cuda_kernel_config_free_ptr)(mlx_fast_cuda_kernel_config cls) = NULL;
int (*mlx_fast_cuda_kernel_config_add_output_arg_ptr)(mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) = NULL;
int (*mlx_fast_cuda_kernel_config_set_grid_ptr)(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) = NULL;
int (*mlx_fast_cuda_kernel_config_set_thread_group_ptr)(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) = NULL;
int (*mlx_fast_cuda_kernel_config_set_init_value_ptr)(mlx_fast_cuda_kernel_config cls, float value) = NULL;
int (*mlx_fast_cuda_kernel_config_set_verbose_ptr)(mlx_fast_cuda_kernel_config cls, bool verbose) = NULL;
int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype) = NULL;
int (*mlx_fast_cuda_kernel_config_add_template_arg_int_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, int value) = NULL;
int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr)(mlx_fast_cuda_kernel_config cls, const char* name, bool value) = NULL;
// CUDA kernel: the final constructor argument is `int shared_memory`
// (the Metal counterpart takes `bool atomic_outputs` in that position).
mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) = NULL;
void (*mlx_fast_cuda_kernel_free_ptr)(mlx_fast_cuda_kernel cls) = NULL;
int (*mlx_fast_cuda_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs, const mlx_fast_cuda_kernel_config config, const mlx_stream stream) = NULL;
int (*mlx_fast_layer_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , const mlx_array bias , float eps, const mlx_stream s) = NULL;
// Metal kernel config: same member set as the CUDA config above.
mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_ptr)(void) = NULL;
void (*mlx_fast_metal_kernel_config_free_ptr)(mlx_fast_metal_kernel_config cls) = NULL;
int (*mlx_fast_metal_kernel_config_add_output_arg_ptr)(mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype) = NULL;
int (*mlx_fast_metal_kernel_config_set_grid_ptr)(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) = NULL;
int (*mlx_fast_metal_kernel_config_set_thread_group_ptr)(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) = NULL;
int (*mlx_fast_metal_kernel_config_set_init_value_ptr)(mlx_fast_metal_kernel_config cls, float value) = NULL;
int (*mlx_fast_metal_kernel_config_set_verbose_ptr)(mlx_fast_metal_kernel_config cls, bool verbose) = NULL;
int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr)(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) = NULL;
int (*mlx_fast_metal_kernel_config_add_template_arg_int_ptr)(mlx_fast_metal_kernel_config cls, const char* name, int value) = NULL;
int (*mlx_fast_metal_kernel_config_add_template_arg_bool_ptr)(mlx_fast_metal_kernel_config cls, const char* name, bool value) = NULL;
mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_ptr)(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) = NULL;
void (*mlx_fast_metal_kernel_free_ptr)(mlx_fast_metal_kernel cls) = NULL;
int (*mlx_fast_metal_kernel_apply_ptr)(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) = NULL;
int (*mlx_fast_rms_norm_ptr)(mlx_array* res, const mlx_array x, const mlx_array weight , float eps, const mlx_stream s) = NULL;
// The two rope variants differ only in `offset`: an int here, an mlx_array in `_dynamic`.
int (*mlx_fast_rope_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs , const mlx_stream s) = NULL;
int (*mlx_fast_rope_dynamic_ptr)(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs , const mlx_stream s) = NULL;
int (*mlx_fast_scaled_dot_product_attention_ptr)(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr , const mlx_array sinks , const mlx_stream s) = NULL;
// --- mlx_fft ---------------------------------------------------------------
// Pointers for the `mlx_fft_*` entry points. All start out NULL. Signature shapes:
//   1-D   (fft, ifft, rfft, irfft):   (res, a, int n, int axis, s)
//   2-D/N-D (*2, *n variants):        (res, a, n[], n_num, axes[], axes_num, s)
//   shift (fftshift, ifftshift):      (res, a, axes[], axes_num, s)
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.
int (*mlx_fft_fft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL;
int (*mlx_fft_fft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_fftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_fftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_ifft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL;
int (*mlx_fft_ifft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_ifftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_ifftshift_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_irfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL;
int (*mlx_fft_irfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_irfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_rfft_ptr)(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) = NULL;
int (*mlx_fft_rfft2_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_fft_rfftn_ptr)(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
// --- load / save and I/O handles ---------------------------------------------
// Pointers for loading and saving arrays (by file path or through an
// mlx_io_reader / mlx_io_writer handle) and for managing those handles.
// All start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.
int (*mlx_load_reader_ptr)(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) = NULL;
int (*mlx_load_ptr)(mlx_array* res, const char* file, const mlx_stream s) = NULL;
// The safetensors loaders fill two maps: string->array (res_0) and string->string (res_1).
int (*mlx_load_safetensors_reader_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) = NULL;
int (*mlx_load_safetensors_ptr)(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) = NULL;
// The save pointers take no stream argument.
int (*mlx_save_writer_ptr)(mlx_io_writer out_stream, const mlx_array a) = NULL;
int (*mlx_save_ptr)(const char* file, const mlx_array a) = NULL;
// The writer parameter below is named `in_stream` even though its type is mlx_io_writer.
int (*mlx_save_safetensors_writer_ptr)(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL;
int (*mlx_save_safetensors_ptr)(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) = NULL;
// Reader/writer handles are built from an opaque descriptor plus an mlx_io_vtable.
mlx_io_reader (*mlx_io_reader_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_reader_descriptor_ptr)(void** desc_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_tostring_ptr)(mlx_string* str_, mlx_io_reader io) = NULL;
int (*mlx_io_reader_free_ptr)(mlx_io_reader io) = NULL;
mlx_io_writer (*mlx_io_writer_new_ptr)(void* desc, mlx_io_vtable vtable) = NULL;
int (*mlx_io_writer_descriptor_ptr)(void** desc_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_tostring_ptr)(mlx_string* str_, mlx_io_writer io) = NULL;
int (*mlx_io_writer_free_ptr)(mlx_io_writer io) = NULL;
// --- mlx_linalg --------------------------------------------------------------
// Pointers for the `mlx_linalg_*` entry points. All return int, take the
// stream as their last parameter and start out NULL. Results come back
// through one `mlx_array*`, two (`res_0`, `res_1`), or an `mlx_vector_array*`.
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.
int (*mlx_linalg_cholesky_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL;
int (*mlx_linalg_cholesky_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL;
int (*mlx_linalg_cross_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL;
int (*mlx_linalg_eig_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_linalg_eigh_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) = NULL;
int (*mlx_linalg_eigvals_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_linalg_eigvalsh_ptr)(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) = NULL;
int (*mlx_linalg_inv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_linalg_lu_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_linalg_lu_factor_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL;
// Three norm variants: numeric `ord` (double), string `ord`, and one with no `ord` parameter.
int (*mlx_linalg_norm_ptr)(mlx_array* res, const mlx_array a, double ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_linalg_norm_matrix_ptr)(mlx_array* res, const mlx_array a, const char* ord, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_linalg_norm_l2_ptr)(mlx_array* res, const mlx_array a, const int* axis , size_t axis_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_linalg_pinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_linalg_qr_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_linalg_solve_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_linalg_solve_triangular_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) = NULL;
int (*mlx_linalg_svd_ptr)(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) = NULL;
int (*mlx_linalg_tri_inv_ptr)(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) = NULL;
// --- string-keyed maps -------------------------------------------------------
// Pointers for two map handle types (string->array and string->string) and
// their iterators. Both families expose the same eight members: new, set,
// free, insert, get, iterator_new, iterator_free and iterator_next.
// All start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this
// chunk; how iterator_next signals end-of-iteration is not visible here.
mlx_map_string_to_array (*mlx_map_string_to_array_new_ptr)(void) = NULL;
int (*mlx_map_string_to_array_set_ptr)(mlx_map_string_to_array* map, const mlx_map_string_to_array src) = NULL;
int (*mlx_map_string_to_array_free_ptr)(mlx_map_string_to_array map) = NULL;
int (*mlx_map_string_to_array_insert_ptr)(mlx_map_string_to_array map, const char* key, const mlx_array value) = NULL;
int (*mlx_map_string_to_array_get_ptr)(mlx_array* value, const mlx_map_string_to_array map, const char* key) = NULL;
mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_ptr)(mlx_map_string_to_array map) = NULL;
int (*mlx_map_string_to_array_iterator_free_ptr)(mlx_map_string_to_array_iterator it) = NULL;
// Writes the current key and value through the two out-parameters.
int (*mlx_map_string_to_array_iterator_next_ptr)(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) = NULL;
mlx_map_string_to_string (*mlx_map_string_to_string_new_ptr)(void) = NULL;
int (*mlx_map_string_to_string_set_ptr)(mlx_map_string_to_string* map, const mlx_map_string_to_string src) = NULL;
int (*mlx_map_string_to_string_free_ptr)(mlx_map_string_to_string map) = NULL;
int (*mlx_map_string_to_string_insert_ptr)(mlx_map_string_to_string map, const char* key, const char* value) = NULL;
int (*mlx_map_string_to_string_get_ptr)(const char** value, const mlx_map_string_to_string map, const char* key) = NULL;
mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_ptr)(mlx_map_string_to_string map) = NULL;
int (*mlx_map_string_to_string_iterator_free_ptr)(mlx_map_string_to_string_iterator it) = NULL;
int (*mlx_map_string_to_string_iterator_next_ptr)(const char** key, const char** value, mlx_map_string_to_string_iterator it) = NULL;
// --- memory management -------------------------------------------------------
// Pointers for cache/peak-memory control and the memory queries and limits.
// All return int and start out NULL. The getters write a size_t through `res`;
// the `set_*_limit` pointers take the new `limit` and also have a `res`
// out-parameter (presumably the previous limit — TODO confirm).
// NOTE(review): presumably resolved via dlsym() by loader code outside this chunk.
int (*mlx_clear_cache_ptr)(void) = NULL;
int (*mlx_get_active_memory_ptr)(size_t* res) = NULL;
int (*mlx_get_cache_memory_ptr)(size_t* res) = NULL;
int (*mlx_get_memory_limit_ptr)(size_t* res) = NULL;
int (*mlx_get_peak_memory_ptr)(size_t* res) = NULL;
int (*mlx_reset_peak_memory_ptr)(void) = NULL;
int (*mlx_set_cache_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_memory_limit_ptr)(size_t* res, size_t limit) = NULL;
int (*mlx_set_wired_limit_ptr)(size_t* res, size_t limit) = NULL;
// --- metal -------------------------------------------------------------------
// Pointers for the `mlx_metal_*` entry points; all start out NULL.
// NOTE(review): presumably resolved via dlsym() by loader code outside this
// chunk; whether these symbols exist in non-Metal builds is not visible here.
// Returns an mlx_metal_device_info_t by value.
mlx_metal_device_info_t (*mlx_metal_device_info_ptr)(void) = NULL;
int (*mlx_metal_is_available_ptr)(bool* res) = NULL;
int (*mlx_metal_start_capture_ptr)(const char* path) = NULL;
int (*mlx_metal_stop_capture_ptr)(void) = NULL;
int (*mlx_abs_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_add_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_addmm_ptr)(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) = NULL;
int (*mlx_all_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_all_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_all_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_allclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) = NULL;
int (*mlx_any_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_any_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_any_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_arange_ptr)(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_arccos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arccosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arcsin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arcsinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arctan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_arctan2_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_arctanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_argmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_argmax_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_argmin_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_argmin_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_argpartition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) = NULL;
int (*mlx_argpartition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) = NULL;
int (*mlx_argsort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
int (*mlx_argsort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_array_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) = NULL;
int (*mlx_as_strided_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) = NULL;
int (*mlx_astype_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_atleast_1d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_2d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_atleast_3d_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bitwise_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_invert_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_bitwise_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_bitwise_xor_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_block_masked_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out , const mlx_array mask_lhs , const mlx_array mask_rhs , const mlx_stream s) = NULL;
int (*mlx_broadcast_arrays_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) = NULL;
int (*mlx_broadcast_to_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
int (*mlx_ceil_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_clip_ptr)(mlx_array* res, const mlx_array a, const mlx_array a_min , const mlx_array a_max , const mlx_stream s) = NULL;
int (*mlx_concatenate_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) = NULL;
int (*mlx_concatenate_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) = NULL;
int (*mlx_conjugate_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_contiguous_ptr)(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) = NULL;
// Convolution entry points: mlx_conv1d/2d/3d, mlx_conv_general and
// mlx_conv_transpose1d/2d/3d. All are defined NULL here.
// The 2d/3d variants take one int per dimension (stride_0, stride_1, ...);
// mlx_conv_general instead takes int arrays, each paired with a size_t length.
// The transpose variants additionally take output_padding arguments.
int (*mlx_conv1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) = NULL;
int (*mlx_conv2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) = NULL;
int (*mlx_conv3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) = NULL;
int (*mlx_conv_general_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) = NULL;
int (*mlx_conv_transpose1d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) = NULL;
int (*mlx_conv_transpose2d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) = NULL;
int (*mlx_conv_transpose3d_ptr)(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) = NULL;
// Op entry points mlx_copy .. mlx_full_like, all defined NULL here.
// Signature outliers in this group:
//   - mlx_depends takes no mlx_stream argument.
//   - mlx_depends and mlx_divmod write to an mlx_vector_array* instead of an mlx_array*.
//   - mlx_dequantize takes mlx_optional_* wrappers for group_size, bits and dtype.
int (*mlx_copy_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_cos_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_cosh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_cummax_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_cummin_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_cumprod_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_cumsum_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_degrees_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_depends_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) = NULL;
int (*mlx_dequantize_ptr)(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_diag_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
int (*mlx_diagonal_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) = NULL;
int (*mlx_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_divmod_ptr)(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_einsum_ptr)(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) = NULL;
int (*mlx_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_erf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_erfinv_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_exp_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_expand_dims_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_expand_dims_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
int (*mlx_expm1_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_eye_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_flatten_ptr)(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s) = NULL;
int (*mlx_floor_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_floor_divide_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_from_fp8_ptr)(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_full_ptr)(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_full_like_ptr)(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) = NULL;
// Gather entry points (mlx_gather, mlx_gather_single, mlx_gather_mm, mlx_gather_qmm),
// all defined NULL here.
// mlx_gather takes an mlx_vector_array of indices plus an axes array; mlx_gather_single
// takes a single mlx_array of indices and one int axis.
// NOTE(review): parameters written with a space before the comma (e.g. "biases ,")
// presumably mark optional arrays in the generator input - confirm against mlx-c headers.
int (*mlx_gather_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL;
int (*mlx_gather_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) = NULL;
int (*mlx_gather_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices , const mlx_array rhs_indices , bool sorted_indices, const mlx_stream s) = NULL;
int (*mlx_gather_qmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , const mlx_array lhs_indices , const mlx_array rhs_indices , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) = NULL;
// Op entry points mlx_greater .. mlx_logsumexp, all defined NULL here.
// mlx_logsumexp comes in three forms that recur for other reductions in this file:
//   *_axes : int array of axes plus its length
//   *_axis : a single int axis
//   (bare) : no axis arguments
int (*mlx_greater_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_greater_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_hadamard_transform_ptr)(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) = NULL;
int (*mlx_identity_ptr)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_imag_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_inner_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_isclose_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) = NULL;
int (*mlx_isfinite_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_isinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_isnan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_isneginf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_isposinf_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_kron_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_left_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_less_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_less_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_linspace_ptr)(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_log_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_log10_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_log1p_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_log2_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_logaddexp_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_logcumsumexp_ptr)(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) = NULL;
int (*mlx_logical_and_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_logical_not_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_logical_or_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_logsumexp_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_logsumexp_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_logsumexp_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
// Op entry points mlx_masked_scatter .. mlx_outer, all defined NULL here.
// mlx_max, mlx_mean and mlx_min each have *_axes / *_axis / bare forms;
// mlx_median is declared only in the axes-array form.
// mlx_meshgrid writes to an mlx_vector_array* instead of an mlx_array*.
int (*mlx_masked_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) = NULL;
int (*mlx_matmul_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_max_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_max_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_max_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_maximum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_mean_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_mean_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_mean_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_median_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_meshgrid_ptr)(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s) = NULL;
int (*mlx_min_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_min_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_min_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_minimum_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_moveaxis_ptr)(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) = NULL;
int (*mlx_multiply_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_nan_to_num_ptr)(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) = NULL;
int (*mlx_negative_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_not_equal_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_number_of_elements_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_ones_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_ones_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_outer_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
// Op entry points mlx_pad .. mlx_rsqrt, all defined NULL here.
// The quantization-related pointers (mlx_qqmm, mlx_quantize, mlx_quantized_matmul)
// pass group_size and bits as mlx_optional_int and the mode as a C string;
// mlx_quantize writes to an mlx_vector_array* instead of an mlx_array*.
int (*mlx_pad_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s) = NULL;
int (*mlx_pad_symmetric_ptr)(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s) = NULL;
int (*mlx_partition_axis_ptr)(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) = NULL;
int (*mlx_partition_ptr)(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) = NULL;
int (*mlx_power_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_prod_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_prod_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_prod_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_put_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
int (*mlx_qqmm_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales , mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_quantize_ptr)(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_quantized_matmul_ptr)(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases , bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) = NULL;
int (*mlx_radians_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_real_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_reciprocal_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_remainder_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_repeat_axis_ptr)(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s) = NULL;
int (*mlx_repeat_ptr)(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s) = NULL;
int (*mlx_reshape_ptr)(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
int (*mlx_right_shift_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_roll_axis_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s) = NULL;
int (*mlx_roll_axes_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_roll_ptr)(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s) = NULL;
int (*mlx_round_ptr)(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s) = NULL;
int (*mlx_rsqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
// Scatter entry points (plain, add, max, min, prod), all defined NULL here.
// Each base form takes an mlx_vector_array of indices plus an int array of axes;
// the matching *_single form takes one mlx_array of indices and a single int axis.
// mlx_scatter_add_axis has the *_single shape but names its third array "values"
// rather than "updates".
int (*mlx_scatter_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_scatter_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
int (*mlx_scatter_add_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_scatter_add_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
int (*mlx_scatter_add_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) = NULL;
int (*mlx_scatter_max_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_scatter_max_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
int (*mlx_scatter_min_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_scatter_min_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
int (*mlx_scatter_prod_ptr)(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_scatter_prod_single_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) = NULL;
// Op entry points mlx_segmented_mm .. mlx_zeros_like, all defined NULL here.
// Signature notes:
//   - mlx_split and mlx_split_sections write to an mlx_vector_array*.
//   - mlx_std_* and mlx_var_* take an extra int ddof after keepdims.
//   - mlx_softmax_* take a bool "precise" where other reductions take keepdims.
//   - mlx_tri names its dtype parameter "type" (mlx_eye-style functions use "dtype").
int (*mlx_segmented_mm_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s) = NULL;
int (*mlx_sigmoid_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_sign_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_sin_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_sinh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_slice_ptr)(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL;
int (*mlx_slice_dynamic_ptr)(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) = NULL;
int (*mlx_slice_update_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) = NULL;
int (*mlx_slice_update_dynamic_ptr)(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_softmax_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) = NULL;
int (*mlx_softmax_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) = NULL;
int (*mlx_softmax_ptr)(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s) = NULL;
int (*mlx_sort_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
int (*mlx_sort_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_split_ptr)(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s) = NULL;
int (*mlx_split_sections_ptr)(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s) = NULL;
int (*mlx_sqrt_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_square_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_squeeze_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_squeeze_axis_ptr)(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) = NULL;
int (*mlx_squeeze_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_stack_axis_ptr)(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) = NULL;
int (*mlx_stack_ptr)(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) = NULL;
int (*mlx_std_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) = NULL;
int (*mlx_std_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) = NULL;
int (*mlx_std_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) = NULL;
int (*mlx_stop_gradient_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_subtract_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) = NULL;
int (*mlx_sum_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_sum_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_sum_ptr)(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) = NULL;
int (*mlx_swapaxes_ptr)(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s) = NULL;
int (*mlx_take_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) = NULL;
int (*mlx_take_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s) = NULL;
int (*mlx_take_along_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) = NULL;
int (*mlx_tan_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_tanh_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_tensordot_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s) = NULL;
int (*mlx_tensordot_axis_ptr)(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) = NULL;
int (*mlx_tile_ptr)(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s) = NULL;
int (*mlx_to_fp8_ptr)(mlx_array* res, const mlx_array x, const mlx_stream s) = NULL;
int (*mlx_topk_axis_ptr)(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s) = NULL;
int (*mlx_topk_ptr)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
int (*mlx_trace_ptr)(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_transpose_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) = NULL;
int (*mlx_transpose_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
int (*mlx_tri_ptr)(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s) = NULL;
int (*mlx_tril_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL;
int (*mlx_triu_ptr)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL;
int (*mlx_unflatten_ptr)(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s) = NULL;
int (*mlx_var_axes_ptr)(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) = NULL;
int (*mlx_var_axis_ptr)(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) = NULL;
int (*mlx_var_ptr)(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) = NULL;
int (*mlx_view_ptr)(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_where_ptr)(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s) = NULL;
int (*mlx_zeros_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) = NULL;
int (*mlx_zeros_like_ptr)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
// mlx_random_* entry points, all defined NULL here.
// mlx_random_key and mlx_random_seed take only a uint64_t seed (no key, no stream);
// every other pointer in this group takes an mlx_array key and an mlx_stream.
// mlx_random_split returns its two results through res_0 and res_1.
int (*mlx_random_bernoulli_ptr)(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_bits_ptr)(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_categorical_shape_ptr)(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_categorical_num_samples_ptr)(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_categorical_ptr)(mlx_array* res, const mlx_array logits, int axis, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_gumbel_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_key_ptr)(mlx_array* res, uint64_t seed) = NULL;
int (*mlx_random_laplace_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_multivariate_normal_ptr)(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_normal_broadcast_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc , const mlx_array scale , const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_normal_ptr)(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_permutation_ptr)(mlx_array* res, const mlx_array x, int axis, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_permutation_arange_ptr)(mlx_array* res, int x, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_randint_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_seed_ptr)(uint64_t seed) = NULL;
int (*mlx_random_split_num_ptr)(mlx_array* res, const mlx_array key, int num, const mlx_stream s) = NULL;
int (*mlx_random_split_ptr)(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s) = NULL;
int (*mlx_random_truncated_normal_ptr)(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
int (*mlx_random_uniform_ptr)(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key , const mlx_stream s) = NULL;
// Stream entry points, all defined NULL here.
// The *_new constructors return an mlx_stream by value and mlx_stream_equal returns
// bool; the remaining pointers return int and report results through their first
// pointer argument where one is present.
mlx_stream (*mlx_stream_new_ptr)(void) = NULL;
mlx_stream (*mlx_stream_new_device_ptr)(mlx_device dev) = NULL;
int (*mlx_stream_set_ptr)(mlx_stream* stream, const mlx_stream src) = NULL;
int (*mlx_stream_free_ptr)(mlx_stream stream) = NULL;
int (*mlx_stream_tostring_ptr)(mlx_string* str, mlx_stream stream) = NULL;
bool (*mlx_stream_equal_ptr)(mlx_stream lhs, mlx_stream rhs) = NULL;
int (*mlx_stream_get_device_ptr)(mlx_device* dev, mlx_stream stream) = NULL;
int (*mlx_stream_get_index_ptr)(int* index, mlx_stream stream) = NULL;
int (*mlx_synchronize_ptr)(mlx_stream stream) = NULL;
int (*mlx_get_default_stream_ptr)(mlx_stream* stream, mlx_device dev) = NULL;
int (*mlx_set_default_stream_ptr)(mlx_stream stream) = NULL;
mlx_stream (*mlx_default_cpu_stream_new_ptr)(void) = NULL;
mlx_stream (*mlx_default_gpu_stream_new_ptr)(void) = NULL;
// mlx_string entry points, all defined NULL here.
// The two constructors return an mlx_string by value; mlx_string_data returns a
// const char* (NOTE(review): its lifetime relative to mlx_string_free is not
// visible from this file - confirm against the mlx-c documentation).
mlx_string (*mlx_string_new_ptr)(void) = NULL;
mlx_string (*mlx_string_new_data_ptr)(const char* str) = NULL;
int (*mlx_string_set_ptr)(mlx_string* str, const mlx_string src) = NULL;
const char* (*mlx_string_data_ptr)(mlx_string str) = NULL;
int (*mlx_string_free_ptr)(mlx_string str) = NULL;
// Evaluation and function-transform entry points (eval, async_eval, checkpoint,
// custom_function/custom_vjp, jvp, vjp, value_and_grad, detail_vmap_*), all defined
// NULL here. None of these signatures take an mlx_stream argument.
// mlx_jvp, mlx_vjp and mlx_detail_vmap_trace return two vectors through res_0/res_1.
int (*mlx_async_eval_ptr)(const mlx_vector_array outputs) = NULL;
int (*mlx_checkpoint_ptr)(mlx_closure* res, const mlx_closure fun) = NULL;
int (*mlx_custom_function_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp , const mlx_closure_custom_jvp fun_jvp , const mlx_closure_custom_vmap fun_vmap) = NULL;
int (*mlx_custom_vjp_ptr)(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp) = NULL;
int (*mlx_eval_ptr)(const mlx_vector_array outputs) = NULL;
int (*mlx_jvp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents) = NULL;
int (*mlx_value_and_grad_ptr)(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num) = NULL;
int (*mlx_vjp_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) = NULL;
int (*mlx_detail_vmap_replace_ptr)(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num) = NULL;
int (*mlx_detail_vmap_trace_ptr)(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num) = NULL;
// mlx_vector_array (vector of mlx_array) entry points, all defined NULL here.
// The new_* constructors return the vector by value; set_* take a pointer to the
// vector while append_* take the vector itself by value.
mlx_vector_array (*mlx_vector_array_new_ptr)(void) = NULL;
int (*mlx_vector_array_set_ptr)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
int (*mlx_vector_array_free_ptr)(mlx_vector_array vec) = NULL;
mlx_vector_array (*mlx_vector_array_new_data_ptr)(const mlx_array* data, size_t size) = NULL;
mlx_vector_array (*mlx_vector_array_new_value_ptr)(const mlx_array val) = NULL;
int (*mlx_vector_array_set_data_ptr)(mlx_vector_array* vec, const mlx_array* data, size_t size) = NULL;
int (*mlx_vector_array_set_value_ptr)(mlx_vector_array* vec, const mlx_array val) = NULL;
int (*mlx_vector_array_append_data_ptr)(mlx_vector_array vec, const mlx_array* data, size_t size) = NULL;
int (*mlx_vector_array_append_value_ptr)(mlx_vector_array vec, const mlx_array val) = NULL;
size_t (*mlx_vector_array_size_ptr)(mlx_vector_array vec) = NULL;
int (*mlx_vector_array_get_ptr)(mlx_array* res, const mlx_vector_array vec, size_t idx) = NULL;
// mlx_vector_vector_array (vector of mlx_vector_array) entry points, all defined
// NULL here. Same shape as the mlx_vector_array_* set: constructors return by value,
// set_* take a pointer to the vector, append_* take the vector by value.
mlx_vector_vector_array (*mlx_vector_vector_array_new_ptr)(void) = NULL;
int (*mlx_vector_vector_array_set_ptr)(mlx_vector_vector_array* vec, const mlx_vector_vector_array src) = NULL;
int (*mlx_vector_vector_array_free_ptr)(mlx_vector_vector_array vec) = NULL;
mlx_vector_vector_array (*mlx_vector_vector_array_new_data_ptr)(const mlx_vector_array* data, size_t size) = NULL;
mlx_vector_vector_array (*mlx_vector_vector_array_new_value_ptr)(const mlx_vector_array val) = NULL;
int (*mlx_vector_vector_array_set_data_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) = NULL;
int (*mlx_vector_vector_array_set_value_ptr)(mlx_vector_vector_array* vec, const mlx_vector_array val) = NULL;
int (*mlx_vector_vector_array_append_data_ptr)(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) = NULL;
int (*mlx_vector_vector_array_append_value_ptr)(mlx_vector_vector_array vec, const mlx_vector_array val) = NULL;
size_t (*mlx_vector_vector_array_size_ptr)(mlx_vector_vector_array vec) = NULL;
int (*mlx_vector_vector_array_get_ptr)(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) = NULL;
// mlx_vector_int entry points, all defined NULL here.
// Unlike the other vector families in this file, the *_data variants take a
// non-const int* (NOTE(review): presumably mirrors the mlx-c header as generated -
// confirm before changing, since these types must match the extern declarations).
mlx_vector_int (*mlx_vector_int_new_ptr)(void) = NULL;
int (*mlx_vector_int_set_ptr)(mlx_vector_int* vec, const mlx_vector_int src) = NULL;
int (*mlx_vector_int_free_ptr)(mlx_vector_int vec) = NULL;
mlx_vector_int (*mlx_vector_int_new_data_ptr)(int* data, size_t size) = NULL;
mlx_vector_int (*mlx_vector_int_new_value_ptr)(int val) = NULL;
int (*mlx_vector_int_set_data_ptr)(mlx_vector_int* vec, int* data, size_t size) = NULL;
int (*mlx_vector_int_set_value_ptr)(mlx_vector_int* vec, int val) = NULL;
int (*mlx_vector_int_append_data_ptr)(mlx_vector_int vec, int* data, size_t size) = NULL;
int (*mlx_vector_int_append_value_ptr)(mlx_vector_int vec, int val) = NULL;
size_t (*mlx_vector_int_size_ptr)(mlx_vector_int vec) = NULL;
int (*mlx_vector_int_get_ptr)(int* res, const mlx_vector_int vec, size_t idx) = NULL;
// mlx_vector_string entry points, all defined NULL here.
// Elements go in as const char* (or const char** with a count); mlx_vector_string_get
// hands an element back through a non-const char** out-parameter.
mlx_vector_string (*mlx_vector_string_new_ptr)(void) = NULL;
int (*mlx_vector_string_set_ptr)(mlx_vector_string* vec, const mlx_vector_string src) = NULL;
int (*mlx_vector_string_free_ptr)(mlx_vector_string vec) = NULL;
mlx_vector_string (*mlx_vector_string_new_data_ptr)(const char** data, size_t size) = NULL;
mlx_vector_string (*mlx_vector_string_new_value_ptr)(const char* val) = NULL;
int (*mlx_vector_string_set_data_ptr)(mlx_vector_string* vec, const char** data, size_t size) = NULL;
int (*mlx_vector_string_set_value_ptr)(mlx_vector_string* vec, const char* val) = NULL;
int (*mlx_vector_string_append_data_ptr)(mlx_vector_string vec, const char** data, size_t size) = NULL;
int (*mlx_vector_string_append_value_ptr)(mlx_vector_string vec, const char* val) = NULL;
size_t (*mlx_vector_string_size_ptr)(mlx_vector_string vec) = NULL;
int (*mlx_vector_string_get_ptr)(char** res, const mlx_vector_string vec, size_t idx) = NULL;
// Pointer for mlx_version, defined NULL here; it reports through an mlx_string*
// out-parameter and returns int.
int (*mlx_version_ptr)(mlx_string* str_) = NULL;
// Initialize all function pointers via dlsym
int mlx_load_functions(void* handle) {
if (handle == NULL) {
fprintf(stderr, "MLX: Invalid library handle\n");
return -1;
}
mlx_dtype_size_ptr = dlsym(handle, "mlx_dtype_size");
if (mlx_dtype_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_dtype_size\n");
return -1;
}
mlx_array_tostring_ptr = dlsym(handle, "mlx_array_tostring");
if (mlx_array_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_tostring\n");
return -1;
}
mlx_array_new_ptr = dlsym(handle, "mlx_array_new");
if (mlx_array_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new\n");
return -1;
}
mlx_array_free_ptr = dlsym(handle, "mlx_array_free");
if (mlx_array_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_free\n");
return -1;
}
mlx_array_new_bool_ptr = dlsym(handle, "mlx_array_new_bool");
if (mlx_array_new_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_bool\n");
return -1;
}
mlx_array_new_int_ptr = dlsym(handle, "mlx_array_new_int");
if (mlx_array_new_int_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_int\n");
return -1;
}
mlx_array_new_float32_ptr = dlsym(handle, "mlx_array_new_float32");
if (mlx_array_new_float32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float32\n");
return -1;
}
mlx_array_new_float_ptr = dlsym(handle, "mlx_array_new_float");
if (mlx_array_new_float_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float\n");
return -1;
}
mlx_array_new_float64_ptr = dlsym(handle, "mlx_array_new_float64");
if (mlx_array_new_float64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_float64\n");
return -1;
}
mlx_array_new_double_ptr = dlsym(handle, "mlx_array_new_double");
if (mlx_array_new_double_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_double\n");
return -1;
}
mlx_array_new_complex_ptr = dlsym(handle, "mlx_array_new_complex");
if (mlx_array_new_complex_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_complex\n");
return -1;
}
mlx_array_new_data_ptr = dlsym(handle, "mlx_array_new_data");
if (mlx_array_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_new_data\n");
return -1;
}
mlx_array_set_ptr = dlsym(handle, "mlx_array_set");
if (mlx_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set\n");
return -1;
}
mlx_array_set_bool_ptr = dlsym(handle, "mlx_array_set_bool");
if (mlx_array_set_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_bool\n");
return -1;
}
mlx_array_set_int_ptr = dlsym(handle, "mlx_array_set_int");
if (mlx_array_set_int_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_int\n");
return -1;
}
mlx_array_set_float32_ptr = dlsym(handle, "mlx_array_set_float32");
if (mlx_array_set_float32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float32\n");
return -1;
}
mlx_array_set_float_ptr = dlsym(handle, "mlx_array_set_float");
if (mlx_array_set_float_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float\n");
return -1;
}
mlx_array_set_float64_ptr = dlsym(handle, "mlx_array_set_float64");
if (mlx_array_set_float64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_float64\n");
return -1;
}
mlx_array_set_double_ptr = dlsym(handle, "mlx_array_set_double");
if (mlx_array_set_double_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_double\n");
return -1;
}
mlx_array_set_complex_ptr = dlsym(handle, "mlx_array_set_complex");
if (mlx_array_set_complex_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_complex\n");
return -1;
}
mlx_array_set_data_ptr = dlsym(handle, "mlx_array_set_data");
if (mlx_array_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_set_data\n");
return -1;
}
mlx_array_itemsize_ptr = dlsym(handle, "mlx_array_itemsize");
if (mlx_array_itemsize_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_itemsize\n");
return -1;
}
mlx_array_size_ptr = dlsym(handle, "mlx_array_size");
if (mlx_array_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_size\n");
return -1;
}
mlx_array_nbytes_ptr = dlsym(handle, "mlx_array_nbytes");
if (mlx_array_nbytes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_nbytes\n");
return -1;
}
mlx_array_ndim_ptr = dlsym(handle, "mlx_array_ndim");
if (mlx_array_ndim_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_ndim\n");
return -1;
}
mlx_array_shape_ptr = dlsym(handle, "mlx_array_shape");
if (mlx_array_shape_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_shape\n");
return -1;
}
mlx_array_strides_ptr = dlsym(handle, "mlx_array_strides");
if (mlx_array_strides_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_strides\n");
return -1;
}
mlx_array_dim_ptr = dlsym(handle, "mlx_array_dim");
if (mlx_array_dim_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dim\n");
return -1;
}
mlx_array_dtype_ptr = dlsym(handle, "mlx_array_dtype");
if (mlx_array_dtype_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_dtype\n");
return -1;
}
mlx_array_eval_ptr = dlsym(handle, "mlx_array_eval");
if (mlx_array_eval_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_eval\n");
return -1;
}
mlx_array_item_bool_ptr = dlsym(handle, "mlx_array_item_bool");
if (mlx_array_item_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bool\n");
return -1;
}
mlx_array_item_uint8_ptr = dlsym(handle, "mlx_array_item_uint8");
if (mlx_array_item_uint8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint8\n");
return -1;
}
mlx_array_item_uint16_ptr = dlsym(handle, "mlx_array_item_uint16");
if (mlx_array_item_uint16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint16\n");
return -1;
}
mlx_array_item_uint32_ptr = dlsym(handle, "mlx_array_item_uint32");
if (mlx_array_item_uint32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint32\n");
return -1;
}
mlx_array_item_uint64_ptr = dlsym(handle, "mlx_array_item_uint64");
if (mlx_array_item_uint64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_uint64\n");
return -1;
}
mlx_array_item_int8_ptr = dlsym(handle, "mlx_array_item_int8");
if (mlx_array_item_int8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int8\n");
return -1;
}
mlx_array_item_int16_ptr = dlsym(handle, "mlx_array_item_int16");
if (mlx_array_item_int16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int16\n");
return -1;
}
mlx_array_item_int32_ptr = dlsym(handle, "mlx_array_item_int32");
if (mlx_array_item_int32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int32\n");
return -1;
}
mlx_array_item_int64_ptr = dlsym(handle, "mlx_array_item_int64");
if (mlx_array_item_int64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_int64\n");
return -1;
}
mlx_array_item_float32_ptr = dlsym(handle, "mlx_array_item_float32");
if (mlx_array_item_float32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float32\n");
return -1;
}
mlx_array_item_float64_ptr = dlsym(handle, "mlx_array_item_float64");
if (mlx_array_item_float64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float64\n");
return -1;
}
mlx_array_item_complex64_ptr = dlsym(handle, "mlx_array_item_complex64");
if (mlx_array_item_complex64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_complex64\n");
return -1;
}
#if defined(__aarch64__) || defined(_M_ARM64)
mlx_array_item_float16_ptr = dlsym(handle, "mlx_array_item_float16");
if (mlx_array_item_float16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_float16\n");
return -1;
}
#endif
#if defined(__aarch64__) || defined(_M_ARM64)
mlx_array_item_bfloat16_ptr = dlsym(handle, "mlx_array_item_bfloat16");
if (mlx_array_item_bfloat16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_item_bfloat16\n");
return -1;
}
#endif
mlx_array_data_bool_ptr = dlsym(handle, "mlx_array_data_bool");
if (mlx_array_data_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bool\n");
return -1;
}
mlx_array_data_uint8_ptr = dlsym(handle, "mlx_array_data_uint8");
if (mlx_array_data_uint8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint8\n");
return -1;
}
mlx_array_data_uint16_ptr = dlsym(handle, "mlx_array_data_uint16");
if (mlx_array_data_uint16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint16\n");
return -1;
}
mlx_array_data_uint32_ptr = dlsym(handle, "mlx_array_data_uint32");
if (mlx_array_data_uint32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint32\n");
return -1;
}
mlx_array_data_uint64_ptr = dlsym(handle, "mlx_array_data_uint64");
if (mlx_array_data_uint64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_uint64\n");
return -1;
}
mlx_array_data_int8_ptr = dlsym(handle, "mlx_array_data_int8");
if (mlx_array_data_int8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int8\n");
return -1;
}
mlx_array_data_int16_ptr = dlsym(handle, "mlx_array_data_int16");
if (mlx_array_data_int16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int16\n");
return -1;
}
mlx_array_data_int32_ptr = dlsym(handle, "mlx_array_data_int32");
if (mlx_array_data_int32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int32\n");
return -1;
}
mlx_array_data_int64_ptr = dlsym(handle, "mlx_array_data_int64");
if (mlx_array_data_int64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_int64\n");
return -1;
}
mlx_array_data_float32_ptr = dlsym(handle, "mlx_array_data_float32");
if (mlx_array_data_float32_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float32\n");
return -1;
}
mlx_array_data_float64_ptr = dlsym(handle, "mlx_array_data_float64");
if (mlx_array_data_float64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float64\n");
return -1;
}
mlx_array_data_complex64_ptr = dlsym(handle, "mlx_array_data_complex64");
if (mlx_array_data_complex64_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_complex64\n");
return -1;
}
#if defined(__aarch64__) || defined(_M_ARM64)
mlx_array_data_float16_ptr = dlsym(handle, "mlx_array_data_float16");
if (mlx_array_data_float16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_float16\n");
return -1;
}
#endif
#if defined(__aarch64__) || defined(_M_ARM64)
mlx_array_data_bfloat16_ptr = dlsym(handle, "mlx_array_data_bfloat16");
if (mlx_array_data_bfloat16_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_data_bfloat16\n");
return -1;
}
#endif
_mlx_array_is_available_ptr = dlsym(handle, "_mlx_array_is_available");
if (_mlx_array_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_available\n");
return -1;
}
_mlx_array_wait_ptr = dlsym(handle, "_mlx_array_wait");
if (_mlx_array_wait_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_wait\n");
return -1;
}
_mlx_array_is_contiguous_ptr = dlsym(handle, "_mlx_array_is_contiguous");
if (_mlx_array_is_contiguous_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_contiguous\n");
return -1;
}
_mlx_array_is_row_contiguous_ptr = dlsym(handle, "_mlx_array_is_row_contiguous");
if (_mlx_array_is_row_contiguous_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_row_contiguous\n");
return -1;
}
_mlx_array_is_col_contiguous_ptr = dlsym(handle, "_mlx_array_is_col_contiguous");
if (_mlx_array_is_col_contiguous_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_array_is_col_contiguous\n");
return -1;
}
mlx_closure_new_ptr = dlsym(handle, "mlx_closure_new");
if (mlx_closure_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new\n");
return -1;
}
mlx_closure_free_ptr = dlsym(handle, "mlx_closure_free");
if (mlx_closure_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_free\n");
return -1;
}
mlx_closure_new_func_ptr = dlsym(handle, "mlx_closure_new_func");
if (mlx_closure_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func\n");
return -1;
}
mlx_closure_new_func_payload_ptr = dlsym(handle, "mlx_closure_new_func_payload");
if (mlx_closure_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_func_payload\n");
return -1;
}
mlx_closure_set_ptr = dlsym(handle, "mlx_closure_set");
if (mlx_closure_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_set\n");
return -1;
}
mlx_closure_apply_ptr = dlsym(handle, "mlx_closure_apply");
if (mlx_closure_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_apply\n");
return -1;
}
mlx_closure_new_unary_ptr = dlsym(handle, "mlx_closure_new_unary");
if (mlx_closure_new_unary_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_new_unary\n");
return -1;
}
mlx_closure_kwargs_new_ptr = dlsym(handle, "mlx_closure_kwargs_new");
if (mlx_closure_kwargs_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new\n");
return -1;
}
mlx_closure_kwargs_free_ptr = dlsym(handle, "mlx_closure_kwargs_free");
if (mlx_closure_kwargs_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_free\n");
return -1;
}
mlx_closure_kwargs_new_func_ptr = dlsym(handle, "mlx_closure_kwargs_new_func");
if (mlx_closure_kwargs_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func\n");
return -1;
}
mlx_closure_kwargs_new_func_payload_ptr = dlsym(handle, "mlx_closure_kwargs_new_func_payload");
if (mlx_closure_kwargs_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_new_func_payload\n");
return -1;
}
mlx_closure_kwargs_set_ptr = dlsym(handle, "mlx_closure_kwargs_set");
if (mlx_closure_kwargs_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_set\n");
return -1;
}
mlx_closure_kwargs_apply_ptr = dlsym(handle, "mlx_closure_kwargs_apply");
if (mlx_closure_kwargs_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_kwargs_apply\n");
return -1;
}
mlx_closure_value_and_grad_new_ptr = dlsym(handle, "mlx_closure_value_and_grad_new");
if (mlx_closure_value_and_grad_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new\n");
return -1;
}
mlx_closure_value_and_grad_free_ptr = dlsym(handle, "mlx_closure_value_and_grad_free");
if (mlx_closure_value_and_grad_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_free\n");
return -1;
}
mlx_closure_value_and_grad_new_func_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func");
if (mlx_closure_value_and_grad_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func\n");
return -1;
}
mlx_closure_value_and_grad_new_func_payload_ptr = dlsym(handle, "mlx_closure_value_and_grad_new_func_payload");
if (mlx_closure_value_and_grad_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_new_func_payload\n");
return -1;
}
mlx_closure_value_and_grad_set_ptr = dlsym(handle, "mlx_closure_value_and_grad_set");
if (mlx_closure_value_and_grad_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_set\n");
return -1;
}
mlx_closure_value_and_grad_apply_ptr = dlsym(handle, "mlx_closure_value_and_grad_apply");
if (mlx_closure_value_and_grad_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_value_and_grad_apply\n");
return -1;
}
mlx_closure_custom_new_ptr = dlsym(handle, "mlx_closure_custom_new");
if (mlx_closure_custom_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new\n");
return -1;
}
mlx_closure_custom_free_ptr = dlsym(handle, "mlx_closure_custom_free");
if (mlx_closure_custom_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_free\n");
return -1;
}
mlx_closure_custom_new_func_ptr = dlsym(handle, "mlx_closure_custom_new_func");
if (mlx_closure_custom_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func\n");
return -1;
}
mlx_closure_custom_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_new_func_payload");
if (mlx_closure_custom_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_new_func_payload\n");
return -1;
}
mlx_closure_custom_set_ptr = dlsym(handle, "mlx_closure_custom_set");
if (mlx_closure_custom_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_set\n");
return -1;
}
mlx_closure_custom_apply_ptr = dlsym(handle, "mlx_closure_custom_apply");
if (mlx_closure_custom_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_apply\n");
return -1;
}
mlx_closure_custom_jvp_new_ptr = dlsym(handle, "mlx_closure_custom_jvp_new");
if (mlx_closure_custom_jvp_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new\n");
return -1;
}
mlx_closure_custom_jvp_free_ptr = dlsym(handle, "mlx_closure_custom_jvp_free");
if (mlx_closure_custom_jvp_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_free\n");
return -1;
}
mlx_closure_custom_jvp_new_func_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func");
if (mlx_closure_custom_jvp_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func\n");
return -1;
}
mlx_closure_custom_jvp_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_jvp_new_func_payload");
if (mlx_closure_custom_jvp_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_new_func_payload\n");
return -1;
}
mlx_closure_custom_jvp_set_ptr = dlsym(handle, "mlx_closure_custom_jvp_set");
if (mlx_closure_custom_jvp_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_set\n");
return -1;
}
mlx_closure_custom_jvp_apply_ptr = dlsym(handle, "mlx_closure_custom_jvp_apply");
if (mlx_closure_custom_jvp_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_jvp_apply\n");
return -1;
}
mlx_closure_custom_vmap_new_ptr = dlsym(handle, "mlx_closure_custom_vmap_new");
if (mlx_closure_custom_vmap_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new\n");
return -1;
}
mlx_closure_custom_vmap_free_ptr = dlsym(handle, "mlx_closure_custom_vmap_free");
if (mlx_closure_custom_vmap_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_free\n");
return -1;
}
mlx_closure_custom_vmap_new_func_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func");
if (mlx_closure_custom_vmap_new_func_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func\n");
return -1;
}
mlx_closure_custom_vmap_new_func_payload_ptr = dlsym(handle, "mlx_closure_custom_vmap_new_func_payload");
if (mlx_closure_custom_vmap_new_func_payload_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_new_func_payload\n");
return -1;
}
mlx_closure_custom_vmap_set_ptr = dlsym(handle, "mlx_closure_custom_vmap_set");
if (mlx_closure_custom_vmap_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_set\n");
return -1;
}
mlx_closure_custom_vmap_apply_ptr = dlsym(handle, "mlx_closure_custom_vmap_apply");
if (mlx_closure_custom_vmap_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_closure_custom_vmap_apply\n");
return -1;
}
mlx_compile_ptr = dlsym(handle, "mlx_compile");
if (mlx_compile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_compile\n");
return -1;
}
mlx_detail_compile_ptr = dlsym(handle, "mlx_detail_compile");
if (mlx_detail_compile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile\n");
return -1;
}
mlx_detail_compile_clear_cache_ptr = dlsym(handle, "mlx_detail_compile_clear_cache");
if (mlx_detail_compile_clear_cache_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_clear_cache\n");
return -1;
}
mlx_detail_compile_erase_ptr = dlsym(handle, "mlx_detail_compile_erase");
if (mlx_detail_compile_erase_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_compile_erase\n");
return -1;
}
mlx_disable_compile_ptr = dlsym(handle, "mlx_disable_compile");
if (mlx_disable_compile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_disable_compile\n");
return -1;
}
mlx_enable_compile_ptr = dlsym(handle, "mlx_enable_compile");
if (mlx_enable_compile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_enable_compile\n");
return -1;
}
mlx_set_compile_mode_ptr = dlsym(handle, "mlx_set_compile_mode");
if (mlx_set_compile_mode_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_compile_mode\n");
return -1;
}
mlx_device_new_ptr = dlsym(handle, "mlx_device_new");
if (mlx_device_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new\n");
return -1;
}
mlx_device_new_type_ptr = dlsym(handle, "mlx_device_new_type");
if (mlx_device_new_type_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_new_type\n");
return -1;
}
mlx_device_free_ptr = dlsym(handle, "mlx_device_free");
if (mlx_device_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_free\n");
return -1;
}
mlx_device_set_ptr = dlsym(handle, "mlx_device_set");
if (mlx_device_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_set\n");
return -1;
}
mlx_device_tostring_ptr = dlsym(handle, "mlx_device_tostring");
if (mlx_device_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_tostring\n");
return -1;
}
mlx_device_equal_ptr = dlsym(handle, "mlx_device_equal");
if (mlx_device_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_equal\n");
return -1;
}
mlx_device_get_index_ptr = dlsym(handle, "mlx_device_get_index");
if (mlx_device_get_index_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_index\n");
return -1;
}
mlx_device_get_type_ptr = dlsym(handle, "mlx_device_get_type");
if (mlx_device_get_type_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_device_get_type\n");
return -1;
}
mlx_get_default_device_ptr = dlsym(handle, "mlx_get_default_device");
if (mlx_get_default_device_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_device\n");
return -1;
}
mlx_set_default_device_ptr = dlsym(handle, "mlx_set_default_device");
if (mlx_set_default_device_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_device\n");
return -1;
}
mlx_distributed_all_gather_ptr = dlsym(handle, "mlx_distributed_all_gather");
if (mlx_distributed_all_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_gather\n");
return -1;
}
mlx_distributed_all_max_ptr = dlsym(handle, "mlx_distributed_all_max");
if (mlx_distributed_all_max_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_max\n");
return -1;
}
mlx_distributed_all_min_ptr = dlsym(handle, "mlx_distributed_all_min");
if (mlx_distributed_all_min_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_min\n");
return -1;
}
mlx_distributed_all_sum_ptr = dlsym(handle, "mlx_distributed_all_sum");
if (mlx_distributed_all_sum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_all_sum\n");
return -1;
}
mlx_distributed_recv_ptr = dlsym(handle, "mlx_distributed_recv");
if (mlx_distributed_recv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv\n");
return -1;
}
mlx_distributed_recv_like_ptr = dlsym(handle, "mlx_distributed_recv_like");
if (mlx_distributed_recv_like_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_recv_like\n");
return -1;
}
mlx_distributed_send_ptr = dlsym(handle, "mlx_distributed_send");
if (mlx_distributed_send_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_send\n");
return -1;
}
mlx_distributed_sum_scatter_ptr = dlsym(handle, "mlx_distributed_sum_scatter");
if (mlx_distributed_sum_scatter_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_sum_scatter\n");
return -1;
}
mlx_distributed_group_rank_ptr = dlsym(handle, "mlx_distributed_group_rank");
if (mlx_distributed_group_rank_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_rank\n");
return -1;
}
mlx_distributed_group_size_ptr = dlsym(handle, "mlx_distributed_group_size");
if (mlx_distributed_group_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_size\n");
return -1;
}
mlx_distributed_group_split_ptr = dlsym(handle, "mlx_distributed_group_split");
if (mlx_distributed_group_split_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_group_split\n");
return -1;
}
mlx_distributed_is_available_ptr = dlsym(handle, "mlx_distributed_is_available");
if (mlx_distributed_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_is_available\n");
return -1;
}
mlx_distributed_init_ptr = dlsym(handle, "mlx_distributed_init");
if (mlx_distributed_init_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_distributed_init\n");
return -1;
}
mlx_set_error_handler_ptr = dlsym(handle, "mlx_set_error_handler");
if (mlx_set_error_handler_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_error_handler\n");
return -1;
}
_mlx_error_ptr = dlsym(handle, "_mlx_error");
if (_mlx_error_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: _mlx_error\n");
return -1;
}
mlx_export_function_ptr = dlsym(handle, "mlx_export_function");
if (mlx_export_function_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function\n");
return -1;
}
mlx_export_function_kwargs_ptr = dlsym(handle, "mlx_export_function_kwargs");
if (mlx_export_function_kwargs_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_export_function_kwargs\n");
return -1;
}
mlx_function_exporter_new_ptr = dlsym(handle, "mlx_function_exporter_new");
if (mlx_function_exporter_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_new\n");
return -1;
}
mlx_function_exporter_free_ptr = dlsym(handle, "mlx_function_exporter_free");
if (mlx_function_exporter_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_free\n");
return -1;
}
mlx_function_exporter_apply_ptr = dlsym(handle, "mlx_function_exporter_apply");
if (mlx_function_exporter_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply\n");
return -1;
}
mlx_function_exporter_apply_kwargs_ptr = dlsym(handle, "mlx_function_exporter_apply_kwargs");
if (mlx_function_exporter_apply_kwargs_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_function_exporter_apply_kwargs\n");
return -1;
}
mlx_imported_function_new_ptr = dlsym(handle, "mlx_imported_function_new");
if (mlx_imported_function_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_new\n");
return -1;
}
mlx_imported_function_free_ptr = dlsym(handle, "mlx_imported_function_free");
if (mlx_imported_function_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_free\n");
return -1;
}
mlx_imported_function_apply_ptr = dlsym(handle, "mlx_imported_function_apply");
if (mlx_imported_function_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply\n");
return -1;
}
mlx_imported_function_apply_kwargs_ptr = dlsym(handle, "mlx_imported_function_apply_kwargs");
if (mlx_imported_function_apply_kwargs_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imported_function_apply_kwargs\n");
return -1;
}
mlx_fast_cuda_kernel_config_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_new");
if (mlx_fast_cuda_kernel_config_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_new\n");
return -1;
}
mlx_fast_cuda_kernel_config_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_free");
if (mlx_fast_cuda_kernel_config_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_free\n");
return -1;
}
mlx_fast_cuda_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_output_arg");
if (mlx_fast_cuda_kernel_config_add_output_arg_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_output_arg\n");
return -1;
}
mlx_fast_cuda_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_grid");
if (mlx_fast_cuda_kernel_config_set_grid_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_grid\n");
return -1;
}
mlx_fast_cuda_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_thread_group");
if (mlx_fast_cuda_kernel_config_set_thread_group_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_thread_group\n");
return -1;
}
mlx_fast_cuda_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_init_value");
if (mlx_fast_cuda_kernel_config_set_init_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_init_value\n");
return -1;
}
mlx_fast_cuda_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_set_verbose");
if (mlx_fast_cuda_kernel_config_set_verbose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_set_verbose\n");
return -1;
}
mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_dtype");
if (mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_dtype\n");
return -1;
}
mlx_fast_cuda_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_int");
if (mlx_fast_cuda_kernel_config_add_template_arg_int_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_int\n");
return -1;
}
mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_cuda_kernel_config_add_template_arg_bool");
if (mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_config_add_template_arg_bool\n");
return -1;
}
mlx_fast_cuda_kernel_new_ptr = dlsym(handle, "mlx_fast_cuda_kernel_new");
if (mlx_fast_cuda_kernel_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_new\n");
return -1;
}
mlx_fast_cuda_kernel_free_ptr = dlsym(handle, "mlx_fast_cuda_kernel_free");
if (mlx_fast_cuda_kernel_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_free\n");
return -1;
}
mlx_fast_cuda_kernel_apply_ptr = dlsym(handle, "mlx_fast_cuda_kernel_apply");
if (mlx_fast_cuda_kernel_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_cuda_kernel_apply\n");
return -1;
}
mlx_fast_layer_norm_ptr = dlsym(handle, "mlx_fast_layer_norm");
if (mlx_fast_layer_norm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_layer_norm\n");
return -1;
}
mlx_fast_metal_kernel_config_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_new");
if (mlx_fast_metal_kernel_config_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_new\n");
return -1;
}
mlx_fast_metal_kernel_config_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_free");
if (mlx_fast_metal_kernel_config_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_free\n");
return -1;
}
mlx_fast_metal_kernel_config_add_output_arg_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_output_arg");
if (mlx_fast_metal_kernel_config_add_output_arg_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_output_arg\n");
return -1;
}
mlx_fast_metal_kernel_config_set_grid_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_grid");
if (mlx_fast_metal_kernel_config_set_grid_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_grid\n");
return -1;
}
mlx_fast_metal_kernel_config_set_thread_group_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_thread_group");
if (mlx_fast_metal_kernel_config_set_thread_group_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_thread_group\n");
return -1;
}
mlx_fast_metal_kernel_config_set_init_value_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_init_value");
if (mlx_fast_metal_kernel_config_set_init_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_init_value\n");
return -1;
}
mlx_fast_metal_kernel_config_set_verbose_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_set_verbose");
if (mlx_fast_metal_kernel_config_set_verbose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_set_verbose\n");
return -1;
}
mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_dtype");
if (mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_dtype\n");
return -1;
}
mlx_fast_metal_kernel_config_add_template_arg_int_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_int");
if (mlx_fast_metal_kernel_config_add_template_arg_int_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_int\n");
return -1;
}
mlx_fast_metal_kernel_config_add_template_arg_bool_ptr = dlsym(handle, "mlx_fast_metal_kernel_config_add_template_arg_bool");
if (mlx_fast_metal_kernel_config_add_template_arg_bool_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_config_add_template_arg_bool\n");
return -1;
}
mlx_fast_metal_kernel_new_ptr = dlsym(handle, "mlx_fast_metal_kernel_new");
if (mlx_fast_metal_kernel_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_new\n");
return -1;
}
mlx_fast_metal_kernel_free_ptr = dlsym(handle, "mlx_fast_metal_kernel_free");
if (mlx_fast_metal_kernel_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_free\n");
return -1;
}
mlx_fast_metal_kernel_apply_ptr = dlsym(handle, "mlx_fast_metal_kernel_apply");
if (mlx_fast_metal_kernel_apply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_metal_kernel_apply\n");
return -1;
}
mlx_fast_rms_norm_ptr = dlsym(handle, "mlx_fast_rms_norm");
if (mlx_fast_rms_norm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rms_norm\n");
return -1;
}
mlx_fast_rope_ptr = dlsym(handle, "mlx_fast_rope");
if (mlx_fast_rope_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope\n");
return -1;
}
mlx_fast_rope_dynamic_ptr = dlsym(handle, "mlx_fast_rope_dynamic");
if (mlx_fast_rope_dynamic_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_rope_dynamic\n");
return -1;
}
mlx_fast_scaled_dot_product_attention_ptr = dlsym(handle, "mlx_fast_scaled_dot_product_attention");
if (mlx_fast_scaled_dot_product_attention_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fast_scaled_dot_product_attention\n");
return -1;
}
mlx_fft_fft_ptr = dlsym(handle, "mlx_fft_fft");
if (mlx_fft_fft_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft\n");
return -1;
}
mlx_fft_fft2_ptr = dlsym(handle, "mlx_fft_fft2");
if (mlx_fft_fft2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fft2\n");
return -1;
}
mlx_fft_fftn_ptr = dlsym(handle, "mlx_fft_fftn");
if (mlx_fft_fftn_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftn\n");
return -1;
}
mlx_fft_fftshift_ptr = dlsym(handle, "mlx_fft_fftshift");
if (mlx_fft_fftshift_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_fftshift\n");
return -1;
}
mlx_fft_ifft_ptr = dlsym(handle, "mlx_fft_ifft");
if (mlx_fft_ifft_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft\n");
return -1;
}
mlx_fft_ifft2_ptr = dlsym(handle, "mlx_fft_ifft2");
if (mlx_fft_ifft2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifft2\n");
return -1;
}
mlx_fft_ifftn_ptr = dlsym(handle, "mlx_fft_ifftn");
if (mlx_fft_ifftn_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftn\n");
return -1;
}
mlx_fft_ifftshift_ptr = dlsym(handle, "mlx_fft_ifftshift");
if (mlx_fft_ifftshift_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_ifftshift\n");
return -1;
}
mlx_fft_irfft_ptr = dlsym(handle, "mlx_fft_irfft");
if (mlx_fft_irfft_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft\n");
return -1;
}
mlx_fft_irfft2_ptr = dlsym(handle, "mlx_fft_irfft2");
if (mlx_fft_irfft2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfft2\n");
return -1;
}
mlx_fft_irfftn_ptr = dlsym(handle, "mlx_fft_irfftn");
if (mlx_fft_irfftn_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_irfftn\n");
return -1;
}
mlx_fft_rfft_ptr = dlsym(handle, "mlx_fft_rfft");
if (mlx_fft_rfft_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft\n");
return -1;
}
mlx_fft_rfft2_ptr = dlsym(handle, "mlx_fft_rfft2");
if (mlx_fft_rfft2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfft2\n");
return -1;
}
mlx_fft_rfftn_ptr = dlsym(handle, "mlx_fft_rfftn");
if (mlx_fft_rfftn_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_fft_rfftn\n");
return -1;
}
mlx_load_reader_ptr = dlsym(handle, "mlx_load_reader");
if (mlx_load_reader_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_reader\n");
return -1;
}
mlx_load_ptr = dlsym(handle, "mlx_load");
if (mlx_load_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_load\n");
return -1;
}
mlx_load_safetensors_reader_ptr = dlsym(handle, "mlx_load_safetensors_reader");
if (mlx_load_safetensors_reader_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors_reader\n");
return -1;
}
mlx_load_safetensors_ptr = dlsym(handle, "mlx_load_safetensors");
if (mlx_load_safetensors_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_load_safetensors\n");
return -1;
}
mlx_save_writer_ptr = dlsym(handle, "mlx_save_writer");
if (mlx_save_writer_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_writer\n");
return -1;
}
mlx_save_ptr = dlsym(handle, "mlx_save");
if (mlx_save_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_save\n");
return -1;
}
mlx_save_safetensors_writer_ptr = dlsym(handle, "mlx_save_safetensors_writer");
if (mlx_save_safetensors_writer_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors_writer\n");
return -1;
}
mlx_save_safetensors_ptr = dlsym(handle, "mlx_save_safetensors");
if (mlx_save_safetensors_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_save_safetensors\n");
return -1;
}
mlx_io_reader_new_ptr = dlsym(handle, "mlx_io_reader_new");
if (mlx_io_reader_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_new\n");
return -1;
}
mlx_io_reader_descriptor_ptr = dlsym(handle, "mlx_io_reader_descriptor");
if (mlx_io_reader_descriptor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_descriptor\n");
return -1;
}
mlx_io_reader_tostring_ptr = dlsym(handle, "mlx_io_reader_tostring");
if (mlx_io_reader_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_tostring\n");
return -1;
}
mlx_io_reader_free_ptr = dlsym(handle, "mlx_io_reader_free");
if (mlx_io_reader_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_reader_free\n");
return -1;
}
mlx_io_writer_new_ptr = dlsym(handle, "mlx_io_writer_new");
if (mlx_io_writer_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_new\n");
return -1;
}
mlx_io_writer_descriptor_ptr = dlsym(handle, "mlx_io_writer_descriptor");
if (mlx_io_writer_descriptor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_descriptor\n");
return -1;
}
mlx_io_writer_tostring_ptr = dlsym(handle, "mlx_io_writer_tostring");
if (mlx_io_writer_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_tostring\n");
return -1;
}
mlx_io_writer_free_ptr = dlsym(handle, "mlx_io_writer_free");
if (mlx_io_writer_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_io_writer_free\n");
return -1;
}
mlx_linalg_cholesky_ptr = dlsym(handle, "mlx_linalg_cholesky");
if (mlx_linalg_cholesky_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky\n");
return -1;
}
mlx_linalg_cholesky_inv_ptr = dlsym(handle, "mlx_linalg_cholesky_inv");
if (mlx_linalg_cholesky_inv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cholesky_inv\n");
return -1;
}
mlx_linalg_cross_ptr = dlsym(handle, "mlx_linalg_cross");
if (mlx_linalg_cross_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_cross\n");
return -1;
}
mlx_linalg_eig_ptr = dlsym(handle, "mlx_linalg_eig");
if (mlx_linalg_eig_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eig\n");
return -1;
}
mlx_linalg_eigh_ptr = dlsym(handle, "mlx_linalg_eigh");
if (mlx_linalg_eigh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigh\n");
return -1;
}
mlx_linalg_eigvals_ptr = dlsym(handle, "mlx_linalg_eigvals");
if (mlx_linalg_eigvals_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvals\n");
return -1;
}
mlx_linalg_eigvalsh_ptr = dlsym(handle, "mlx_linalg_eigvalsh");
if (mlx_linalg_eigvalsh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_eigvalsh\n");
return -1;
}
mlx_linalg_inv_ptr = dlsym(handle, "mlx_linalg_inv");
if (mlx_linalg_inv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_inv\n");
return -1;
}
mlx_linalg_lu_ptr = dlsym(handle, "mlx_linalg_lu");
if (mlx_linalg_lu_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu\n");
return -1;
}
mlx_linalg_lu_factor_ptr = dlsym(handle, "mlx_linalg_lu_factor");
if (mlx_linalg_lu_factor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_lu_factor\n");
return -1;
}
mlx_linalg_norm_ptr = dlsym(handle, "mlx_linalg_norm");
if (mlx_linalg_norm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm\n");
return -1;
}
mlx_linalg_norm_matrix_ptr = dlsym(handle, "mlx_linalg_norm_matrix");
if (mlx_linalg_norm_matrix_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_matrix\n");
return -1;
}
mlx_linalg_norm_l2_ptr = dlsym(handle, "mlx_linalg_norm_l2");
if (mlx_linalg_norm_l2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_norm_l2\n");
return -1;
}
mlx_linalg_pinv_ptr = dlsym(handle, "mlx_linalg_pinv");
if (mlx_linalg_pinv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_pinv\n");
return -1;
}
mlx_linalg_qr_ptr = dlsym(handle, "mlx_linalg_qr");
if (mlx_linalg_qr_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_qr\n");
return -1;
}
mlx_linalg_solve_ptr = dlsym(handle, "mlx_linalg_solve");
if (mlx_linalg_solve_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve\n");
return -1;
}
mlx_linalg_solve_triangular_ptr = dlsym(handle, "mlx_linalg_solve_triangular");
if (mlx_linalg_solve_triangular_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_solve_triangular\n");
return -1;
}
mlx_linalg_svd_ptr = dlsym(handle, "mlx_linalg_svd");
if (mlx_linalg_svd_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_svd\n");
return -1;
}
mlx_linalg_tri_inv_ptr = dlsym(handle, "mlx_linalg_tri_inv");
if (mlx_linalg_tri_inv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linalg_tri_inv\n");
return -1;
}
mlx_map_string_to_array_new_ptr = dlsym(handle, "mlx_map_string_to_array_new");
if (mlx_map_string_to_array_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_new\n");
return -1;
}
mlx_map_string_to_array_set_ptr = dlsym(handle, "mlx_map_string_to_array_set");
if (mlx_map_string_to_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_set\n");
return -1;
}
mlx_map_string_to_array_free_ptr = dlsym(handle, "mlx_map_string_to_array_free");
if (mlx_map_string_to_array_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_free\n");
return -1;
}
mlx_map_string_to_array_insert_ptr = dlsym(handle, "mlx_map_string_to_array_insert");
if (mlx_map_string_to_array_insert_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_insert\n");
return -1;
}
mlx_map_string_to_array_get_ptr = dlsym(handle, "mlx_map_string_to_array_get");
if (mlx_map_string_to_array_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_get\n");
return -1;
}
mlx_map_string_to_array_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_new");
if (mlx_map_string_to_array_iterator_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_new\n");
return -1;
}
mlx_map_string_to_array_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_free");
if (mlx_map_string_to_array_iterator_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_free\n");
return -1;
}
mlx_map_string_to_array_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_array_iterator_next");
if (mlx_map_string_to_array_iterator_next_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_array_iterator_next\n");
return -1;
}
mlx_map_string_to_string_new_ptr = dlsym(handle, "mlx_map_string_to_string_new");
if (mlx_map_string_to_string_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_new\n");
return -1;
}
mlx_map_string_to_string_set_ptr = dlsym(handle, "mlx_map_string_to_string_set");
if (mlx_map_string_to_string_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_set\n");
return -1;
}
mlx_map_string_to_string_free_ptr = dlsym(handle, "mlx_map_string_to_string_free");
if (mlx_map_string_to_string_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_free\n");
return -1;
}
mlx_map_string_to_string_insert_ptr = dlsym(handle, "mlx_map_string_to_string_insert");
if (mlx_map_string_to_string_insert_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_insert\n");
return -1;
}
mlx_map_string_to_string_get_ptr = dlsym(handle, "mlx_map_string_to_string_get");
if (mlx_map_string_to_string_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_get\n");
return -1;
}
mlx_map_string_to_string_iterator_new_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_new");
if (mlx_map_string_to_string_iterator_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_new\n");
return -1;
}
mlx_map_string_to_string_iterator_free_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_free");
if (mlx_map_string_to_string_iterator_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_free\n");
return -1;
}
mlx_map_string_to_string_iterator_next_ptr = dlsym(handle, "mlx_map_string_to_string_iterator_next");
if (mlx_map_string_to_string_iterator_next_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_map_string_to_string_iterator_next\n");
return -1;
}
mlx_clear_cache_ptr = dlsym(handle, "mlx_clear_cache");
if (mlx_clear_cache_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_clear_cache\n");
return -1;
}
mlx_get_active_memory_ptr = dlsym(handle, "mlx_get_active_memory");
if (mlx_get_active_memory_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_active_memory\n");
return -1;
}
mlx_get_cache_memory_ptr = dlsym(handle, "mlx_get_cache_memory");
if (mlx_get_cache_memory_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_cache_memory\n");
return -1;
}
mlx_get_memory_limit_ptr = dlsym(handle, "mlx_get_memory_limit");
if (mlx_get_memory_limit_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_memory_limit\n");
return -1;
}
mlx_get_peak_memory_ptr = dlsym(handle, "mlx_get_peak_memory");
if (mlx_get_peak_memory_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_peak_memory\n");
return -1;
}
mlx_reset_peak_memory_ptr = dlsym(handle, "mlx_reset_peak_memory");
if (mlx_reset_peak_memory_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_reset_peak_memory\n");
return -1;
}
mlx_set_cache_limit_ptr = dlsym(handle, "mlx_set_cache_limit");
if (mlx_set_cache_limit_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_cache_limit\n");
return -1;
}
mlx_set_memory_limit_ptr = dlsym(handle, "mlx_set_memory_limit");
if (mlx_set_memory_limit_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_memory_limit\n");
return -1;
}
mlx_set_wired_limit_ptr = dlsym(handle, "mlx_set_wired_limit");
if (mlx_set_wired_limit_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_wired_limit\n");
return -1;
}
mlx_metal_device_info_ptr = dlsym(handle, "mlx_metal_device_info");
if (mlx_metal_device_info_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_device_info\n");
return -1;
}
mlx_metal_is_available_ptr = dlsym(handle, "mlx_metal_is_available");
if (mlx_metal_is_available_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_is_available\n");
return -1;
}
mlx_metal_start_capture_ptr = dlsym(handle, "mlx_metal_start_capture");
if (mlx_metal_start_capture_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_start_capture\n");
return -1;
}
mlx_metal_stop_capture_ptr = dlsym(handle, "mlx_metal_stop_capture");
if (mlx_metal_stop_capture_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_metal_stop_capture\n");
return -1;
}
mlx_abs_ptr = dlsym(handle, "mlx_abs");
if (mlx_abs_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_abs\n");
return -1;
}
mlx_add_ptr = dlsym(handle, "mlx_add");
if (mlx_add_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_add\n");
return -1;
}
mlx_addmm_ptr = dlsym(handle, "mlx_addmm");
if (mlx_addmm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_addmm\n");
return -1;
}
mlx_all_axes_ptr = dlsym(handle, "mlx_all_axes");
if (mlx_all_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axes\n");
return -1;
}
mlx_all_axis_ptr = dlsym(handle, "mlx_all_axis");
if (mlx_all_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_all_axis\n");
return -1;
}
mlx_all_ptr = dlsym(handle, "mlx_all");
if (mlx_all_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_all\n");
return -1;
}
mlx_allclose_ptr = dlsym(handle, "mlx_allclose");
if (mlx_allclose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_allclose\n");
return -1;
}
mlx_any_axes_ptr = dlsym(handle, "mlx_any_axes");
if (mlx_any_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axes\n");
return -1;
}
mlx_any_axis_ptr = dlsym(handle, "mlx_any_axis");
if (mlx_any_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_any_axis\n");
return -1;
}
mlx_any_ptr = dlsym(handle, "mlx_any");
if (mlx_any_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_any\n");
return -1;
}
mlx_arange_ptr = dlsym(handle, "mlx_arange");
if (mlx_arange_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arange\n");
return -1;
}
mlx_arccos_ptr = dlsym(handle, "mlx_arccos");
if (mlx_arccos_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arccos\n");
return -1;
}
mlx_arccosh_ptr = dlsym(handle, "mlx_arccosh");
if (mlx_arccosh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arccosh\n");
return -1;
}
mlx_arcsin_ptr = dlsym(handle, "mlx_arcsin");
if (mlx_arcsin_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsin\n");
return -1;
}
mlx_arcsinh_ptr = dlsym(handle, "mlx_arcsinh");
if (mlx_arcsinh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arcsinh\n");
return -1;
}
mlx_arctan_ptr = dlsym(handle, "mlx_arctan");
if (mlx_arctan_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan\n");
return -1;
}
mlx_arctan2_ptr = dlsym(handle, "mlx_arctan2");
if (mlx_arctan2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctan2\n");
return -1;
}
mlx_arctanh_ptr = dlsym(handle, "mlx_arctanh");
if (mlx_arctanh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_arctanh\n");
return -1;
}
mlx_argmax_axis_ptr = dlsym(handle, "mlx_argmax_axis");
if (mlx_argmax_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax_axis\n");
return -1;
}
mlx_argmax_ptr = dlsym(handle, "mlx_argmax");
if (mlx_argmax_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmax\n");
return -1;
}
mlx_argmin_axis_ptr = dlsym(handle, "mlx_argmin_axis");
if (mlx_argmin_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin_axis\n");
return -1;
}
mlx_argmin_ptr = dlsym(handle, "mlx_argmin");
if (mlx_argmin_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argmin\n");
return -1;
}
mlx_argpartition_axis_ptr = dlsym(handle, "mlx_argpartition_axis");
if (mlx_argpartition_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition_axis\n");
return -1;
}
mlx_argpartition_ptr = dlsym(handle, "mlx_argpartition");
if (mlx_argpartition_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argpartition\n");
return -1;
}
mlx_argsort_axis_ptr = dlsym(handle, "mlx_argsort_axis");
if (mlx_argsort_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort_axis\n");
return -1;
}
mlx_argsort_ptr = dlsym(handle, "mlx_argsort");
if (mlx_argsort_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_argsort\n");
return -1;
}
mlx_array_equal_ptr = dlsym(handle, "mlx_array_equal");
if (mlx_array_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_array_equal\n");
return -1;
}
mlx_as_strided_ptr = dlsym(handle, "mlx_as_strided");
if (mlx_as_strided_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_as_strided\n");
return -1;
}
mlx_astype_ptr = dlsym(handle, "mlx_astype");
if (mlx_astype_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_astype\n");
return -1;
}
mlx_atleast_1d_ptr = dlsym(handle, "mlx_atleast_1d");
if (mlx_atleast_1d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_1d\n");
return -1;
}
mlx_atleast_2d_ptr = dlsym(handle, "mlx_atleast_2d");
if (mlx_atleast_2d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_2d\n");
return -1;
}
mlx_atleast_3d_ptr = dlsym(handle, "mlx_atleast_3d");
if (mlx_atleast_3d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_atleast_3d\n");
return -1;
}
mlx_bitwise_and_ptr = dlsym(handle, "mlx_bitwise_and");
if (mlx_bitwise_and_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_and\n");
return -1;
}
mlx_bitwise_invert_ptr = dlsym(handle, "mlx_bitwise_invert");
if (mlx_bitwise_invert_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_invert\n");
return -1;
}
mlx_bitwise_or_ptr = dlsym(handle, "mlx_bitwise_or");
if (mlx_bitwise_or_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_or\n");
return -1;
}
mlx_bitwise_xor_ptr = dlsym(handle, "mlx_bitwise_xor");
if (mlx_bitwise_xor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_bitwise_xor\n");
return -1;
}
mlx_block_masked_mm_ptr = dlsym(handle, "mlx_block_masked_mm");
if (mlx_block_masked_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_block_masked_mm\n");
return -1;
}
mlx_broadcast_arrays_ptr = dlsym(handle, "mlx_broadcast_arrays");
if (mlx_broadcast_arrays_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_arrays\n");
return -1;
}
mlx_broadcast_to_ptr = dlsym(handle, "mlx_broadcast_to");
if (mlx_broadcast_to_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_broadcast_to\n");
return -1;
}
mlx_ceil_ptr = dlsym(handle, "mlx_ceil");
if (mlx_ceil_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_ceil\n");
return -1;
}
mlx_clip_ptr = dlsym(handle, "mlx_clip");
if (mlx_clip_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_clip\n");
return -1;
}
mlx_concatenate_axis_ptr = dlsym(handle, "mlx_concatenate_axis");
if (mlx_concatenate_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate_axis\n");
return -1;
}
mlx_concatenate_ptr = dlsym(handle, "mlx_concatenate");
if (mlx_concatenate_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_concatenate\n");
return -1;
}
mlx_conjugate_ptr = dlsym(handle, "mlx_conjugate");
if (mlx_conjugate_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conjugate\n");
return -1;
}
mlx_contiguous_ptr = dlsym(handle, "mlx_contiguous");
if (mlx_contiguous_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_contiguous\n");
return -1;
}
mlx_conv1d_ptr = dlsym(handle, "mlx_conv1d");
if (mlx_conv1d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv1d\n");
return -1;
}
mlx_conv2d_ptr = dlsym(handle, "mlx_conv2d");
if (mlx_conv2d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv2d\n");
return -1;
}
mlx_conv3d_ptr = dlsym(handle, "mlx_conv3d");
if (mlx_conv3d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv3d\n");
return -1;
}
mlx_conv_general_ptr = dlsym(handle, "mlx_conv_general");
if (mlx_conv_general_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_general\n");
return -1;
}
mlx_conv_transpose1d_ptr = dlsym(handle, "mlx_conv_transpose1d");
if (mlx_conv_transpose1d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose1d\n");
return -1;
}
mlx_conv_transpose2d_ptr = dlsym(handle, "mlx_conv_transpose2d");
if (mlx_conv_transpose2d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose2d\n");
return -1;
}
mlx_conv_transpose3d_ptr = dlsym(handle, "mlx_conv_transpose3d");
if (mlx_conv_transpose3d_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_conv_transpose3d\n");
return -1;
}
mlx_copy_ptr = dlsym(handle, "mlx_copy");
if (mlx_copy_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_copy\n");
return -1;
}
mlx_cos_ptr = dlsym(handle, "mlx_cos");
if (mlx_cos_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cos\n");
return -1;
}
mlx_cosh_ptr = dlsym(handle, "mlx_cosh");
if (mlx_cosh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cosh\n");
return -1;
}
mlx_cummax_ptr = dlsym(handle, "mlx_cummax");
if (mlx_cummax_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cummax\n");
return -1;
}
mlx_cummin_ptr = dlsym(handle, "mlx_cummin");
if (mlx_cummin_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cummin\n");
return -1;
}
mlx_cumprod_ptr = dlsym(handle, "mlx_cumprod");
if (mlx_cumprod_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cumprod\n");
return -1;
}
mlx_cumsum_ptr = dlsym(handle, "mlx_cumsum");
if (mlx_cumsum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_cumsum\n");
return -1;
}
mlx_degrees_ptr = dlsym(handle, "mlx_degrees");
if (mlx_degrees_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_degrees\n");
return -1;
}
mlx_depends_ptr = dlsym(handle, "mlx_depends");
if (mlx_depends_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_depends\n");
return -1;
}
mlx_dequantize_ptr = dlsym(handle, "mlx_dequantize");
if (mlx_dequantize_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_dequantize\n");
return -1;
}
mlx_diag_ptr = dlsym(handle, "mlx_diag");
if (mlx_diag_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_diag\n");
return -1;
}
mlx_diagonal_ptr = dlsym(handle, "mlx_diagonal");
if (mlx_diagonal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_diagonal\n");
return -1;
}
mlx_divide_ptr = dlsym(handle, "mlx_divide");
if (mlx_divide_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_divide\n");
return -1;
}
mlx_divmod_ptr = dlsym(handle, "mlx_divmod");
if (mlx_divmod_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_divmod\n");
return -1;
}
mlx_einsum_ptr = dlsym(handle, "mlx_einsum");
if (mlx_einsum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_einsum\n");
return -1;
}
mlx_equal_ptr = dlsym(handle, "mlx_equal");
if (mlx_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_equal\n");
return -1;
}
mlx_erf_ptr = dlsym(handle, "mlx_erf");
if (mlx_erf_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_erf\n");
return -1;
}
mlx_erfinv_ptr = dlsym(handle, "mlx_erfinv");
if (mlx_erfinv_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_erfinv\n");
return -1;
}
mlx_exp_ptr = dlsym(handle, "mlx_exp");
if (mlx_exp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_exp\n");
return -1;
}
mlx_expand_dims_axes_ptr = dlsym(handle, "mlx_expand_dims_axes");
if (mlx_expand_dims_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims_axes\n");
return -1;
}
mlx_expand_dims_ptr = dlsym(handle, "mlx_expand_dims");
if (mlx_expand_dims_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_expand_dims\n");
return -1;
}
mlx_expm1_ptr = dlsym(handle, "mlx_expm1");
if (mlx_expm1_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_expm1\n");
return -1;
}
mlx_eye_ptr = dlsym(handle, "mlx_eye");
if (mlx_eye_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_eye\n");
return -1;
}
mlx_flatten_ptr = dlsym(handle, "mlx_flatten");
if (mlx_flatten_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_flatten\n");
return -1;
}
mlx_floor_ptr = dlsym(handle, "mlx_floor");
if (mlx_floor_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_floor\n");
return -1;
}
mlx_floor_divide_ptr = dlsym(handle, "mlx_floor_divide");
if (mlx_floor_divide_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_floor_divide\n");
return -1;
}
mlx_from_fp8_ptr = dlsym(handle, "mlx_from_fp8");
if (mlx_from_fp8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_from_fp8\n");
return -1;
}
mlx_full_ptr = dlsym(handle, "mlx_full");
if (mlx_full_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_full\n");
return -1;
}
mlx_full_like_ptr = dlsym(handle, "mlx_full_like");
if (mlx_full_like_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_full_like\n");
return -1;
}
mlx_gather_ptr = dlsym(handle, "mlx_gather");
if (mlx_gather_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather\n");
return -1;
}
mlx_gather_single_ptr = dlsym(handle, "mlx_gather_single");
if (mlx_gather_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_single\n");
return -1;
}
mlx_gather_mm_ptr = dlsym(handle, "mlx_gather_mm");
if (mlx_gather_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_mm\n");
return -1;
}
mlx_gather_qmm_ptr = dlsym(handle, "mlx_gather_qmm");
if (mlx_gather_qmm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_gather_qmm\n");
return -1;
}
mlx_greater_ptr = dlsym(handle, "mlx_greater");
if (mlx_greater_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_greater\n");
return -1;
}
mlx_greater_equal_ptr = dlsym(handle, "mlx_greater_equal");
if (mlx_greater_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_greater_equal\n");
return -1;
}
mlx_hadamard_transform_ptr = dlsym(handle, "mlx_hadamard_transform");
if (mlx_hadamard_transform_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_hadamard_transform\n");
return -1;
}
mlx_identity_ptr = dlsym(handle, "mlx_identity");
if (mlx_identity_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_identity\n");
return -1;
}
mlx_imag_ptr = dlsym(handle, "mlx_imag");
if (mlx_imag_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_imag\n");
return -1;
}
mlx_inner_ptr = dlsym(handle, "mlx_inner");
if (mlx_inner_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_inner\n");
return -1;
}
mlx_isclose_ptr = dlsym(handle, "mlx_isclose");
if (mlx_isclose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isclose\n");
return -1;
}
mlx_isfinite_ptr = dlsym(handle, "mlx_isfinite");
if (mlx_isfinite_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isfinite\n");
return -1;
}
mlx_isinf_ptr = dlsym(handle, "mlx_isinf");
if (mlx_isinf_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isinf\n");
return -1;
}
mlx_isnan_ptr = dlsym(handle, "mlx_isnan");
if (mlx_isnan_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isnan\n");
return -1;
}
mlx_isneginf_ptr = dlsym(handle, "mlx_isneginf");
if (mlx_isneginf_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isneginf\n");
return -1;
}
mlx_isposinf_ptr = dlsym(handle, "mlx_isposinf");
if (mlx_isposinf_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_isposinf\n");
return -1;
}
mlx_kron_ptr = dlsym(handle, "mlx_kron");
if (mlx_kron_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_kron\n");
return -1;
}
mlx_left_shift_ptr = dlsym(handle, "mlx_left_shift");
if (mlx_left_shift_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_left_shift\n");
return -1;
}
mlx_less_ptr = dlsym(handle, "mlx_less");
if (mlx_less_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_less\n");
return -1;
}
mlx_less_equal_ptr = dlsym(handle, "mlx_less_equal");
if (mlx_less_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_less_equal\n");
return -1;
}
mlx_linspace_ptr = dlsym(handle, "mlx_linspace");
if (mlx_linspace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_linspace\n");
return -1;
}
mlx_log_ptr = dlsym(handle, "mlx_log");
if (mlx_log_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_log\n");
return -1;
}
mlx_log10_ptr = dlsym(handle, "mlx_log10");
if (mlx_log10_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_log10\n");
return -1;
}
mlx_log1p_ptr = dlsym(handle, "mlx_log1p");
if (mlx_log1p_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_log1p\n");
return -1;
}
mlx_log2_ptr = dlsym(handle, "mlx_log2");
if (mlx_log2_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_log2\n");
return -1;
}
mlx_logaddexp_ptr = dlsym(handle, "mlx_logaddexp");
if (mlx_logaddexp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logaddexp\n");
return -1;
}
mlx_logcumsumexp_ptr = dlsym(handle, "mlx_logcumsumexp");
if (mlx_logcumsumexp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logcumsumexp\n");
return -1;
}
mlx_logical_and_ptr = dlsym(handle, "mlx_logical_and");
if (mlx_logical_and_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_and\n");
return -1;
}
mlx_logical_not_ptr = dlsym(handle, "mlx_logical_not");
if (mlx_logical_not_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_not\n");
return -1;
}
mlx_logical_or_ptr = dlsym(handle, "mlx_logical_or");
if (mlx_logical_or_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logical_or\n");
return -1;
}
mlx_logsumexp_axes_ptr = dlsym(handle, "mlx_logsumexp_axes");
if (mlx_logsumexp_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axes\n");
return -1;
}
mlx_logsumexp_axis_ptr = dlsym(handle, "mlx_logsumexp_axis");
if (mlx_logsumexp_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp_axis\n");
return -1;
}
mlx_logsumexp_ptr = dlsym(handle, "mlx_logsumexp");
if (mlx_logsumexp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_logsumexp\n");
return -1;
}
mlx_masked_scatter_ptr = dlsym(handle, "mlx_masked_scatter");
if (mlx_masked_scatter_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_masked_scatter\n");
return -1;
}
mlx_matmul_ptr = dlsym(handle, "mlx_matmul");
if (mlx_matmul_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_matmul\n");
return -1;
}
mlx_max_axes_ptr = dlsym(handle, "mlx_max_axes");
if (mlx_max_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axes\n");
return -1;
}
mlx_max_axis_ptr = dlsym(handle, "mlx_max_axis");
if (mlx_max_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_max_axis\n");
return -1;
}
mlx_max_ptr = dlsym(handle, "mlx_max");
if (mlx_max_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_max\n");
return -1;
}
mlx_maximum_ptr = dlsym(handle, "mlx_maximum");
if (mlx_maximum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_maximum\n");
return -1;
}
mlx_mean_axes_ptr = dlsym(handle, "mlx_mean_axes");
if (mlx_mean_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axes\n");
return -1;
}
mlx_mean_axis_ptr = dlsym(handle, "mlx_mean_axis");
if (mlx_mean_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean_axis\n");
return -1;
}
mlx_mean_ptr = dlsym(handle, "mlx_mean");
if (mlx_mean_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_mean\n");
return -1;
}
mlx_median_ptr = dlsym(handle, "mlx_median");
if (mlx_median_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_median\n");
return -1;
}
mlx_meshgrid_ptr = dlsym(handle, "mlx_meshgrid");
if (mlx_meshgrid_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_meshgrid\n");
return -1;
}
mlx_min_axes_ptr = dlsym(handle, "mlx_min_axes");
if (mlx_min_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axes\n");
return -1;
}
mlx_min_axis_ptr = dlsym(handle, "mlx_min_axis");
if (mlx_min_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_min_axis\n");
return -1;
}
mlx_min_ptr = dlsym(handle, "mlx_min");
if (mlx_min_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_min\n");
return -1;
}
mlx_minimum_ptr = dlsym(handle, "mlx_minimum");
if (mlx_minimum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_minimum\n");
return -1;
}
mlx_moveaxis_ptr = dlsym(handle, "mlx_moveaxis");
if (mlx_moveaxis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_moveaxis\n");
return -1;
}
mlx_multiply_ptr = dlsym(handle, "mlx_multiply");
if (mlx_multiply_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_multiply\n");
return -1;
}
mlx_nan_to_num_ptr = dlsym(handle, "mlx_nan_to_num");
if (mlx_nan_to_num_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_nan_to_num\n");
return -1;
}
mlx_negative_ptr = dlsym(handle, "mlx_negative");
if (mlx_negative_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_negative\n");
return -1;
}
mlx_not_equal_ptr = dlsym(handle, "mlx_not_equal");
if (mlx_not_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_not_equal\n");
return -1;
}
mlx_number_of_elements_ptr = dlsym(handle, "mlx_number_of_elements");
if (mlx_number_of_elements_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_number_of_elements\n");
return -1;
}
mlx_ones_ptr = dlsym(handle, "mlx_ones");
if (mlx_ones_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_ones\n");
return -1;
}
mlx_ones_like_ptr = dlsym(handle, "mlx_ones_like");
if (mlx_ones_like_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_ones_like\n");
return -1;
}
mlx_outer_ptr = dlsym(handle, "mlx_outer");
if (mlx_outer_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_outer\n");
return -1;
}
mlx_pad_ptr = dlsym(handle, "mlx_pad");
if (mlx_pad_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_pad\n");
return -1;
}
mlx_pad_symmetric_ptr = dlsym(handle, "mlx_pad_symmetric");
if (mlx_pad_symmetric_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_pad_symmetric\n");
return -1;
}
mlx_partition_axis_ptr = dlsym(handle, "mlx_partition_axis");
if (mlx_partition_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_partition_axis\n");
return -1;
}
mlx_partition_ptr = dlsym(handle, "mlx_partition");
if (mlx_partition_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_partition\n");
return -1;
}
mlx_power_ptr = dlsym(handle, "mlx_power");
if (mlx_power_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_power\n");
return -1;
}
mlx_prod_axes_ptr = dlsym(handle, "mlx_prod_axes");
if (mlx_prod_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axes\n");
return -1;
}
mlx_prod_axis_ptr = dlsym(handle, "mlx_prod_axis");
if (mlx_prod_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod_axis\n");
return -1;
}
mlx_prod_ptr = dlsym(handle, "mlx_prod");
if (mlx_prod_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_prod\n");
return -1;
}
mlx_put_along_axis_ptr = dlsym(handle, "mlx_put_along_axis");
if (mlx_put_along_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_put_along_axis\n");
return -1;
}
mlx_qqmm_ptr = dlsym(handle, "mlx_qqmm");
if (mlx_qqmm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_qqmm\n");
return -1;
}
mlx_quantize_ptr = dlsym(handle, "mlx_quantize");
if (mlx_quantize_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_quantize\n");
return -1;
}
mlx_quantized_matmul_ptr = dlsym(handle, "mlx_quantized_matmul");
if (mlx_quantized_matmul_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_quantized_matmul\n");
return -1;
}
mlx_radians_ptr = dlsym(handle, "mlx_radians");
if (mlx_radians_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_radians\n");
return -1;
}
mlx_real_ptr = dlsym(handle, "mlx_real");
if (mlx_real_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_real\n");
return -1;
}
mlx_reciprocal_ptr = dlsym(handle, "mlx_reciprocal");
if (mlx_reciprocal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_reciprocal\n");
return -1;
}
mlx_remainder_ptr = dlsym(handle, "mlx_remainder");
if (mlx_remainder_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_remainder\n");
return -1;
}
mlx_repeat_axis_ptr = dlsym(handle, "mlx_repeat_axis");
if (mlx_repeat_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat_axis\n");
return -1;
}
mlx_repeat_ptr = dlsym(handle, "mlx_repeat");
if (mlx_repeat_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_repeat\n");
return -1;
}
mlx_reshape_ptr = dlsym(handle, "mlx_reshape");
if (mlx_reshape_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_reshape\n");
return -1;
}
mlx_right_shift_ptr = dlsym(handle, "mlx_right_shift");
if (mlx_right_shift_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_right_shift\n");
return -1;
}
mlx_roll_axis_ptr = dlsym(handle, "mlx_roll_axis");
if (mlx_roll_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axis\n");
return -1;
}
mlx_roll_axes_ptr = dlsym(handle, "mlx_roll_axes");
if (mlx_roll_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll_axes\n");
return -1;
}
mlx_roll_ptr = dlsym(handle, "mlx_roll");
if (mlx_roll_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_roll\n");
return -1;
}
mlx_round_ptr = dlsym(handle, "mlx_round");
if (mlx_round_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_round\n");
return -1;
}
mlx_rsqrt_ptr = dlsym(handle, "mlx_rsqrt");
if (mlx_rsqrt_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_rsqrt\n");
return -1;
}
mlx_scatter_ptr = dlsym(handle, "mlx_scatter");
if (mlx_scatter_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter\n");
return -1;
}
mlx_scatter_single_ptr = dlsym(handle, "mlx_scatter_single");
if (mlx_scatter_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_single\n");
return -1;
}
mlx_scatter_add_ptr = dlsym(handle, "mlx_scatter_add");
if (mlx_scatter_add_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add\n");
return -1;
}
mlx_scatter_add_single_ptr = dlsym(handle, "mlx_scatter_add_single");
if (mlx_scatter_add_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_single\n");
return -1;
}
mlx_scatter_add_axis_ptr = dlsym(handle, "mlx_scatter_add_axis");
if (mlx_scatter_add_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_add_axis\n");
return -1;
}
mlx_scatter_max_ptr = dlsym(handle, "mlx_scatter_max");
if (mlx_scatter_max_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max\n");
return -1;
}
mlx_scatter_max_single_ptr = dlsym(handle, "mlx_scatter_max_single");
if (mlx_scatter_max_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_max_single\n");
return -1;
}
mlx_scatter_min_ptr = dlsym(handle, "mlx_scatter_min");
if (mlx_scatter_min_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min\n");
return -1;
}
mlx_scatter_min_single_ptr = dlsym(handle, "mlx_scatter_min_single");
if (mlx_scatter_min_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_min_single\n");
return -1;
}
mlx_scatter_prod_ptr = dlsym(handle, "mlx_scatter_prod");
if (mlx_scatter_prod_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod\n");
return -1;
}
mlx_scatter_prod_single_ptr = dlsym(handle, "mlx_scatter_prod_single");
if (mlx_scatter_prod_single_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_scatter_prod_single\n");
return -1;
}
mlx_segmented_mm_ptr = dlsym(handle, "mlx_segmented_mm");
if (mlx_segmented_mm_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_segmented_mm\n");
return -1;
}
mlx_sigmoid_ptr = dlsym(handle, "mlx_sigmoid");
if (mlx_sigmoid_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sigmoid\n");
return -1;
}
mlx_sign_ptr = dlsym(handle, "mlx_sign");
if (mlx_sign_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sign\n");
return -1;
}
mlx_sin_ptr = dlsym(handle, "mlx_sin");
if (mlx_sin_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sin\n");
return -1;
}
mlx_sinh_ptr = dlsym(handle, "mlx_sinh");
if (mlx_sinh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sinh\n");
return -1;
}
mlx_slice_ptr = dlsym(handle, "mlx_slice");
if (mlx_slice_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice\n");
return -1;
}
mlx_slice_dynamic_ptr = dlsym(handle, "mlx_slice_dynamic");
if (mlx_slice_dynamic_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_dynamic\n");
return -1;
}
mlx_slice_update_ptr = dlsym(handle, "mlx_slice_update");
if (mlx_slice_update_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update\n");
return -1;
}
mlx_slice_update_dynamic_ptr = dlsym(handle, "mlx_slice_update_dynamic");
if (mlx_slice_update_dynamic_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_slice_update_dynamic\n");
return -1;
}
mlx_softmax_axes_ptr = dlsym(handle, "mlx_softmax_axes");
if (mlx_softmax_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axes\n");
return -1;
}
mlx_softmax_axis_ptr = dlsym(handle, "mlx_softmax_axis");
if (mlx_softmax_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax_axis\n");
return -1;
}
mlx_softmax_ptr = dlsym(handle, "mlx_softmax");
if (mlx_softmax_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_softmax\n");
return -1;
}
mlx_sort_axis_ptr = dlsym(handle, "mlx_sort_axis");
if (mlx_sort_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sort_axis\n");
return -1;
}
mlx_sort_ptr = dlsym(handle, "mlx_sort");
if (mlx_sort_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sort\n");
return -1;
}
mlx_split_ptr = dlsym(handle, "mlx_split");
if (mlx_split_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_split\n");
return -1;
}
mlx_split_sections_ptr = dlsym(handle, "mlx_split_sections");
if (mlx_split_sections_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_split_sections\n");
return -1;
}
mlx_sqrt_ptr = dlsym(handle, "mlx_sqrt");
if (mlx_sqrt_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sqrt\n");
return -1;
}
mlx_square_ptr = dlsym(handle, "mlx_square");
if (mlx_square_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_square\n");
return -1;
}
mlx_squeeze_axes_ptr = dlsym(handle, "mlx_squeeze_axes");
if (mlx_squeeze_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axes\n");
return -1;
}
mlx_squeeze_axis_ptr = dlsym(handle, "mlx_squeeze_axis");
if (mlx_squeeze_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze_axis\n");
return -1;
}
mlx_squeeze_ptr = dlsym(handle, "mlx_squeeze");
if (mlx_squeeze_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_squeeze\n");
return -1;
}
mlx_stack_axis_ptr = dlsym(handle, "mlx_stack_axis");
if (mlx_stack_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stack_axis\n");
return -1;
}
mlx_stack_ptr = dlsym(handle, "mlx_stack");
if (mlx_stack_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stack\n");
return -1;
}
mlx_std_axes_ptr = dlsym(handle, "mlx_std_axes");
if (mlx_std_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axes\n");
return -1;
}
mlx_std_axis_ptr = dlsym(handle, "mlx_std_axis");
if (mlx_std_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_std_axis\n");
return -1;
}
mlx_std_ptr = dlsym(handle, "mlx_std");
if (mlx_std_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_std\n");
return -1;
}
mlx_stop_gradient_ptr = dlsym(handle, "mlx_stop_gradient");
if (mlx_stop_gradient_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stop_gradient\n");
return -1;
}
mlx_subtract_ptr = dlsym(handle, "mlx_subtract");
if (mlx_subtract_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_subtract\n");
return -1;
}
mlx_sum_axes_ptr = dlsym(handle, "mlx_sum_axes");
if (mlx_sum_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axes\n");
return -1;
}
mlx_sum_axis_ptr = dlsym(handle, "mlx_sum_axis");
if (mlx_sum_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum_axis\n");
return -1;
}
mlx_sum_ptr = dlsym(handle, "mlx_sum");
if (mlx_sum_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_sum\n");
return -1;
}
mlx_swapaxes_ptr = dlsym(handle, "mlx_swapaxes");
if (mlx_swapaxes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_swapaxes\n");
return -1;
}
mlx_take_axis_ptr = dlsym(handle, "mlx_take_axis");
if (mlx_take_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_take_axis\n");
return -1;
}
mlx_take_ptr = dlsym(handle, "mlx_take");
if (mlx_take_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_take\n");
return -1;
}
mlx_take_along_axis_ptr = dlsym(handle, "mlx_take_along_axis");
if (mlx_take_along_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_take_along_axis\n");
return -1;
}
mlx_tan_ptr = dlsym(handle, "mlx_tan");
if (mlx_tan_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tan\n");
return -1;
}
mlx_tanh_ptr = dlsym(handle, "mlx_tanh");
if (mlx_tanh_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tanh\n");
return -1;
}
mlx_tensordot_ptr = dlsym(handle, "mlx_tensordot");
if (mlx_tensordot_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot\n");
return -1;
}
mlx_tensordot_axis_ptr = dlsym(handle, "mlx_tensordot_axis");
if (mlx_tensordot_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tensordot_axis\n");
return -1;
}
mlx_tile_ptr = dlsym(handle, "mlx_tile");
if (mlx_tile_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tile\n");
return -1;
}
mlx_to_fp8_ptr = dlsym(handle, "mlx_to_fp8");
if (mlx_to_fp8_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_to_fp8\n");
return -1;
}
mlx_topk_axis_ptr = dlsym(handle, "mlx_topk_axis");
if (mlx_topk_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_topk_axis\n");
return -1;
}
mlx_topk_ptr = dlsym(handle, "mlx_topk");
if (mlx_topk_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_topk\n");
return -1;
}
mlx_trace_ptr = dlsym(handle, "mlx_trace");
if (mlx_trace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_trace\n");
return -1;
}
mlx_transpose_axes_ptr = dlsym(handle, "mlx_transpose_axes");
if (mlx_transpose_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose_axes\n");
return -1;
}
mlx_transpose_ptr = dlsym(handle, "mlx_transpose");
if (mlx_transpose_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_transpose\n");
return -1;
}
mlx_tri_ptr = dlsym(handle, "mlx_tri");
if (mlx_tri_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tri\n");
return -1;
}
mlx_tril_ptr = dlsym(handle, "mlx_tril");
if (mlx_tril_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_tril\n");
return -1;
}
mlx_triu_ptr = dlsym(handle, "mlx_triu");
if (mlx_triu_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_triu\n");
return -1;
}
mlx_unflatten_ptr = dlsym(handle, "mlx_unflatten");
if (mlx_unflatten_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_unflatten\n");
return -1;
}
mlx_var_axes_ptr = dlsym(handle, "mlx_var_axes");
if (mlx_var_axes_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axes\n");
return -1;
}
mlx_var_axis_ptr = dlsym(handle, "mlx_var_axis");
if (mlx_var_axis_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_var_axis\n");
return -1;
}
mlx_var_ptr = dlsym(handle, "mlx_var");
if (mlx_var_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_var\n");
return -1;
}
mlx_view_ptr = dlsym(handle, "mlx_view");
if (mlx_view_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_view\n");
return -1;
}
mlx_where_ptr = dlsym(handle, "mlx_where");
if (mlx_where_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_where\n");
return -1;
}
mlx_zeros_ptr = dlsym(handle, "mlx_zeros");
if (mlx_zeros_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros\n");
return -1;
}
mlx_zeros_like_ptr = dlsym(handle, "mlx_zeros_like");
if (mlx_zeros_like_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_zeros_like\n");
return -1;
}
mlx_random_bernoulli_ptr = dlsym(handle, "mlx_random_bernoulli");
if (mlx_random_bernoulli_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bernoulli\n");
return -1;
}
mlx_random_bits_ptr = dlsym(handle, "mlx_random_bits");
if (mlx_random_bits_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_bits\n");
return -1;
}
mlx_random_categorical_shape_ptr = dlsym(handle, "mlx_random_categorical_shape");
if (mlx_random_categorical_shape_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_shape\n");
return -1;
}
mlx_random_categorical_num_samples_ptr = dlsym(handle, "mlx_random_categorical_num_samples");
if (mlx_random_categorical_num_samples_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical_num_samples\n");
return -1;
}
mlx_random_categorical_ptr = dlsym(handle, "mlx_random_categorical");
if (mlx_random_categorical_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_categorical\n");
return -1;
}
mlx_random_gumbel_ptr = dlsym(handle, "mlx_random_gumbel");
if (mlx_random_gumbel_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_gumbel\n");
return -1;
}
mlx_random_key_ptr = dlsym(handle, "mlx_random_key");
if (mlx_random_key_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_key\n");
return -1;
}
mlx_random_laplace_ptr = dlsym(handle, "mlx_random_laplace");
if (mlx_random_laplace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_laplace\n");
return -1;
}
mlx_random_multivariate_normal_ptr = dlsym(handle, "mlx_random_multivariate_normal");
if (mlx_random_multivariate_normal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_multivariate_normal\n");
return -1;
}
mlx_random_normal_broadcast_ptr = dlsym(handle, "mlx_random_normal_broadcast");
if (mlx_random_normal_broadcast_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal_broadcast\n");
return -1;
}
mlx_random_normal_ptr = dlsym(handle, "mlx_random_normal");
if (mlx_random_normal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_normal\n");
return -1;
}
mlx_random_permutation_ptr = dlsym(handle, "mlx_random_permutation");
if (mlx_random_permutation_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation\n");
return -1;
}
mlx_random_permutation_arange_ptr = dlsym(handle, "mlx_random_permutation_arange");
if (mlx_random_permutation_arange_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_permutation_arange\n");
return -1;
}
mlx_random_randint_ptr = dlsym(handle, "mlx_random_randint");
if (mlx_random_randint_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_randint\n");
return -1;
}
mlx_random_seed_ptr = dlsym(handle, "mlx_random_seed");
if (mlx_random_seed_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_seed\n");
return -1;
}
mlx_random_split_num_ptr = dlsym(handle, "mlx_random_split_num");
if (mlx_random_split_num_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split_num\n");
return -1;
}
mlx_random_split_ptr = dlsym(handle, "mlx_random_split");
if (mlx_random_split_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_split\n");
return -1;
}
mlx_random_truncated_normal_ptr = dlsym(handle, "mlx_random_truncated_normal");
if (mlx_random_truncated_normal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_truncated_normal\n");
return -1;
}
mlx_random_uniform_ptr = dlsym(handle, "mlx_random_uniform");
if (mlx_random_uniform_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_random_uniform\n");
return -1;
}
mlx_stream_new_ptr = dlsym(handle, "mlx_stream_new");
if (mlx_stream_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new\n");
return -1;
}
mlx_stream_new_device_ptr = dlsym(handle, "mlx_stream_new_device");
if (mlx_stream_new_device_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_new_device\n");
return -1;
}
mlx_stream_set_ptr = dlsym(handle, "mlx_stream_set");
if (mlx_stream_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_set\n");
return -1;
}
mlx_stream_free_ptr = dlsym(handle, "mlx_stream_free");
if (mlx_stream_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_free\n");
return -1;
}
mlx_stream_tostring_ptr = dlsym(handle, "mlx_stream_tostring");
if (mlx_stream_tostring_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_tostring\n");
return -1;
}
mlx_stream_equal_ptr = dlsym(handle, "mlx_stream_equal");
if (mlx_stream_equal_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_equal\n");
return -1;
}
mlx_stream_get_device_ptr = dlsym(handle, "mlx_stream_get_device");
if (mlx_stream_get_device_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_device\n");
return -1;
}
mlx_stream_get_index_ptr = dlsym(handle, "mlx_stream_get_index");
if (mlx_stream_get_index_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_stream_get_index\n");
return -1;
}
mlx_synchronize_ptr = dlsym(handle, "mlx_synchronize");
if (mlx_synchronize_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_synchronize\n");
return -1;
}
mlx_get_default_stream_ptr = dlsym(handle, "mlx_get_default_stream");
if (mlx_get_default_stream_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_get_default_stream\n");
return -1;
}
mlx_set_default_stream_ptr = dlsym(handle, "mlx_set_default_stream");
if (mlx_set_default_stream_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_set_default_stream\n");
return -1;
}
mlx_default_cpu_stream_new_ptr = dlsym(handle, "mlx_default_cpu_stream_new");
if (mlx_default_cpu_stream_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_default_cpu_stream_new\n");
return -1;
}
mlx_default_gpu_stream_new_ptr = dlsym(handle, "mlx_default_gpu_stream_new");
if (mlx_default_gpu_stream_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_default_gpu_stream_new\n");
return -1;
}
mlx_string_new_ptr = dlsym(handle, "mlx_string_new");
if (mlx_string_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new\n");
return -1;
}
mlx_string_new_data_ptr = dlsym(handle, "mlx_string_new_data");
if (mlx_string_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_new_data\n");
return -1;
}
mlx_string_set_ptr = dlsym(handle, "mlx_string_set");
if (mlx_string_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_set\n");
return -1;
}
mlx_string_data_ptr = dlsym(handle, "mlx_string_data");
if (mlx_string_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_data\n");
return -1;
}
mlx_string_free_ptr = dlsym(handle, "mlx_string_free");
if (mlx_string_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_string_free\n");
return -1;
}
mlx_async_eval_ptr = dlsym(handle, "mlx_async_eval");
if (mlx_async_eval_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_async_eval\n");
return -1;
}
mlx_checkpoint_ptr = dlsym(handle, "mlx_checkpoint");
if (mlx_checkpoint_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_checkpoint\n");
return -1;
}
mlx_custom_function_ptr = dlsym(handle, "mlx_custom_function");
if (mlx_custom_function_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_function\n");
return -1;
}
mlx_custom_vjp_ptr = dlsym(handle, "mlx_custom_vjp");
if (mlx_custom_vjp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_custom_vjp\n");
return -1;
}
mlx_eval_ptr = dlsym(handle, "mlx_eval");
if (mlx_eval_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_eval\n");
return -1;
}
mlx_jvp_ptr = dlsym(handle, "mlx_jvp");
if (mlx_jvp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_jvp\n");
return -1;
}
mlx_value_and_grad_ptr = dlsym(handle, "mlx_value_and_grad");
if (mlx_value_and_grad_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_value_and_grad\n");
return -1;
}
mlx_vjp_ptr = dlsym(handle, "mlx_vjp");
if (mlx_vjp_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vjp\n");
return -1;
}
mlx_detail_vmap_replace_ptr = dlsym(handle, "mlx_detail_vmap_replace");
if (mlx_detail_vmap_replace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_replace\n");
return -1;
}
mlx_detail_vmap_trace_ptr = dlsym(handle, "mlx_detail_vmap_trace");
if (mlx_detail_vmap_trace_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_detail_vmap_trace\n");
return -1;
}
mlx_vector_array_new_ptr = dlsym(handle, "mlx_vector_array_new");
if (mlx_vector_array_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new\n");
return -1;
}
mlx_vector_array_set_ptr = dlsym(handle, "mlx_vector_array_set");
if (mlx_vector_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set\n");
return -1;
}
mlx_vector_array_free_ptr = dlsym(handle, "mlx_vector_array_free");
if (mlx_vector_array_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_free\n");
return -1;
}
mlx_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_array_new_data");
if (mlx_vector_array_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_data\n");
return -1;
}
mlx_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_array_new_value");
if (mlx_vector_array_new_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_new_value\n");
return -1;
}
mlx_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_array_set_data");
if (mlx_vector_array_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_data\n");
return -1;
}
mlx_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_array_set_value");
if (mlx_vector_array_set_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_set_value\n");
return -1;
}
mlx_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_array_append_data");
if (mlx_vector_array_append_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_data\n");
return -1;
}
mlx_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_array_append_value");
if (mlx_vector_array_append_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_append_value\n");
return -1;
}
mlx_vector_array_size_ptr = dlsym(handle, "mlx_vector_array_size");
if (mlx_vector_array_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_size\n");
return -1;
}
mlx_vector_array_get_ptr = dlsym(handle, "mlx_vector_array_get");
if (mlx_vector_array_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_array_get\n");
return -1;
}
mlx_vector_vector_array_new_ptr = dlsym(handle, "mlx_vector_vector_array_new");
if (mlx_vector_vector_array_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new\n");
return -1;
}
mlx_vector_vector_array_set_ptr = dlsym(handle, "mlx_vector_vector_array_set");
if (mlx_vector_vector_array_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set\n");
return -1;
}
mlx_vector_vector_array_free_ptr = dlsym(handle, "mlx_vector_vector_array_free");
if (mlx_vector_vector_array_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_free\n");
return -1;
}
mlx_vector_vector_array_new_data_ptr = dlsym(handle, "mlx_vector_vector_array_new_data");
if (mlx_vector_vector_array_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_data\n");
return -1;
}
mlx_vector_vector_array_new_value_ptr = dlsym(handle, "mlx_vector_vector_array_new_value");
if (mlx_vector_vector_array_new_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_new_value\n");
return -1;
}
mlx_vector_vector_array_set_data_ptr = dlsym(handle, "mlx_vector_vector_array_set_data");
if (mlx_vector_vector_array_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_data\n");
return -1;
}
mlx_vector_vector_array_set_value_ptr = dlsym(handle, "mlx_vector_vector_array_set_value");
if (mlx_vector_vector_array_set_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_set_value\n");
return -1;
}
mlx_vector_vector_array_append_data_ptr = dlsym(handle, "mlx_vector_vector_array_append_data");
if (mlx_vector_vector_array_append_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_data\n");
return -1;
}
mlx_vector_vector_array_append_value_ptr = dlsym(handle, "mlx_vector_vector_array_append_value");
if (mlx_vector_vector_array_append_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_append_value\n");
return -1;
}
mlx_vector_vector_array_size_ptr = dlsym(handle, "mlx_vector_vector_array_size");
if (mlx_vector_vector_array_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_size\n");
return -1;
}
mlx_vector_vector_array_get_ptr = dlsym(handle, "mlx_vector_vector_array_get");
if (mlx_vector_vector_array_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_vector_array_get\n");
return -1;
}
mlx_vector_int_new_ptr = dlsym(handle, "mlx_vector_int_new");
if (mlx_vector_int_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new\n");
return -1;
}
mlx_vector_int_set_ptr = dlsym(handle, "mlx_vector_int_set");
if (mlx_vector_int_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set\n");
return -1;
}
mlx_vector_int_free_ptr = dlsym(handle, "mlx_vector_int_free");
if (mlx_vector_int_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_free\n");
return -1;
}
mlx_vector_int_new_data_ptr = dlsym(handle, "mlx_vector_int_new_data");
if (mlx_vector_int_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_data\n");
return -1;
}
mlx_vector_int_new_value_ptr = dlsym(handle, "mlx_vector_int_new_value");
if (mlx_vector_int_new_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_new_value\n");
return -1;
}
mlx_vector_int_set_data_ptr = dlsym(handle, "mlx_vector_int_set_data");
if (mlx_vector_int_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_data\n");
return -1;
}
mlx_vector_int_set_value_ptr = dlsym(handle, "mlx_vector_int_set_value");
if (mlx_vector_int_set_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_set_value\n");
return -1;
}
mlx_vector_int_append_data_ptr = dlsym(handle, "mlx_vector_int_append_data");
if (mlx_vector_int_append_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_data\n");
return -1;
}
mlx_vector_int_append_value_ptr = dlsym(handle, "mlx_vector_int_append_value");
if (mlx_vector_int_append_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_append_value\n");
return -1;
}
mlx_vector_int_size_ptr = dlsym(handle, "mlx_vector_int_size");
if (mlx_vector_int_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_size\n");
return -1;
}
mlx_vector_int_get_ptr = dlsym(handle, "mlx_vector_int_get");
if (mlx_vector_int_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_int_get\n");
return -1;
}
mlx_vector_string_new_ptr = dlsym(handle, "mlx_vector_string_new");
if (mlx_vector_string_new_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new\n");
return -1;
}
mlx_vector_string_set_ptr = dlsym(handle, "mlx_vector_string_set");
if (mlx_vector_string_set_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set\n");
return -1;
}
mlx_vector_string_free_ptr = dlsym(handle, "mlx_vector_string_free");
if (mlx_vector_string_free_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_free\n");
return -1;
}
mlx_vector_string_new_data_ptr = dlsym(handle, "mlx_vector_string_new_data");
if (mlx_vector_string_new_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_data\n");
return -1;
}
mlx_vector_string_new_value_ptr = dlsym(handle, "mlx_vector_string_new_value");
if (mlx_vector_string_new_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_new_value\n");
return -1;
}
mlx_vector_string_set_data_ptr = dlsym(handle, "mlx_vector_string_set_data");
if (mlx_vector_string_set_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_data\n");
return -1;
}
mlx_vector_string_set_value_ptr = dlsym(handle, "mlx_vector_string_set_value");
if (mlx_vector_string_set_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_set_value\n");
return -1;
}
mlx_vector_string_append_data_ptr = dlsym(handle, "mlx_vector_string_append_data");
if (mlx_vector_string_append_data_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_data\n");
return -1;
}
mlx_vector_string_append_value_ptr = dlsym(handle, "mlx_vector_string_append_value");
if (mlx_vector_string_append_value_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_append_value\n");
return -1;
}
mlx_vector_string_size_ptr = dlsym(handle, "mlx_vector_string_size");
if (mlx_vector_string_size_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_size\n");
return -1;
}
mlx_vector_string_get_ptr = dlsym(handle, "mlx_vector_string_get");
if (mlx_vector_string_get_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_vector_string_get\n");
return -1;
}
mlx_version_ptr = dlsym(handle, "mlx_version");
if (mlx_version_ptr == NULL) {
fprintf(stderr, "MLX: Failed to load symbol: mlx_version\n");
return -1;
}
return 0;
}
// Wrapper function implementations that call through function pointers
// Forwards to mlx_dtype_size_ptr and returns its result unchanged.
size_t mlx_dtype_size(mlx_dtype dtype) {
    return (*mlx_dtype_size_ptr)(dtype);
}
// Forwards to mlx_array_tostring_ptr and returns its status code unchanged.
int mlx_array_tostring(mlx_string* str, const mlx_array arr) {
    return (*mlx_array_tostring_ptr)(str, arr);
}
mlx_array mlx_array_new(void) {
return mlx_array_new_ptr();
}
// Forwards to mlx_array_free_ptr and returns its status code unchanged.
int mlx_array_free(mlx_array arr) {
    return (*mlx_array_free_ptr)(arr);
}
mlx_array mlx_array_new_bool(bool val) {
return mlx_array_new_bool_ptr(val);
}
mlx_array mlx_array_new_int(int val) {
return mlx_array_new_int_ptr(val);
}
mlx_array mlx_array_new_float32(float val) {
return mlx_array_new_float32_ptr(val);
}
mlx_array mlx_array_new_float(float val) {
return mlx_array_new_float_ptr(val);
}
mlx_array mlx_array_new_float64(double val) {
return mlx_array_new_float64_ptr(val);
}
mlx_array mlx_array_new_double(double val) {
return mlx_array_new_double_ptr(val);
}
mlx_array mlx_array_new_complex(float real_val, float imag_val) {
return mlx_array_new_complex_ptr(real_val, imag_val);
}
// Forwards to mlx_array_new_data_ptr and returns the array it produces.
mlx_array mlx_array_new_data(const void* data, const int* shape, int dim, mlx_dtype dtype) {
    return (*mlx_array_new_data_ptr)(data, shape, dim, dtype);
}
// Forwards to mlx_array_set_ptr and returns its status code unchanged.
int mlx_array_set(mlx_array* arr, const mlx_array src) {
    return (*mlx_array_set_ptr)(arr, src);
}
// Forwards to mlx_array_set_bool_ptr and returns its status code unchanged.
int mlx_array_set_bool(mlx_array* arr, bool val) {
    return (*mlx_array_set_bool_ptr)(arr, val);
}
// Forwards to mlx_array_set_int_ptr and returns its status code unchanged.
int mlx_array_set_int(mlx_array* arr, int val) {
    return (*mlx_array_set_int_ptr)(arr, val);
}
// Forwards to mlx_array_set_float32_ptr and returns its status code unchanged.
int mlx_array_set_float32(mlx_array* arr, float val) {
    return (*mlx_array_set_float32_ptr)(arr, val);
}
// Forwards to mlx_array_set_float_ptr and returns its status code unchanged.
int mlx_array_set_float(mlx_array* arr, float val) {
    return (*mlx_array_set_float_ptr)(arr, val);
}
// Forwards to mlx_array_set_float64_ptr and returns its status code unchanged.
int mlx_array_set_float64(mlx_array* arr, double val) {
    return (*mlx_array_set_float64_ptr)(arr, val);
}
// Forwards to mlx_array_set_double_ptr and returns its status code unchanged.
int mlx_array_set_double(mlx_array* arr, double val) {
    return (*mlx_array_set_double_ptr)(arr, val);
}
// Forwards to mlx_array_set_complex_ptr and returns its status code unchanged.
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) {
    return (*mlx_array_set_complex_ptr)(arr, real_val, imag_val);
}
// Forwards to mlx_array_set_data_ptr and returns its status code unchanged.
int mlx_array_set_data(
    mlx_array* arr,
    const void* data,
    const int* shape,
    int dim,
    mlx_dtype dtype) {
    return (*mlx_array_set_data_ptr)(arr, data, shape, dim, dtype);
}
size_t mlx_array_itemsize(const mlx_array arr) {
return mlx_array_itemsize_ptr(arr);
}
size_t mlx_array_size(const mlx_array arr) {
return mlx_array_size_ptr(arr);
}
size_t mlx_array_nbytes(const mlx_array arr) {
return mlx_array_nbytes_ptr(arr);
}
size_t mlx_array_ndim(const mlx_array arr) {
return mlx_array_ndim_ptr(arr);
}
const int* mlx_array_shape(const mlx_array arr) {
return mlx_array_shape_ptr(arr);
}
const size_t* mlx_array_strides(const mlx_array arr) {
return mlx_array_strides_ptr(arr);
}
int mlx_array_dim(const mlx_array arr, int dim) {
return mlx_array_dim_ptr(arr, dim);
}
mlx_dtype mlx_array_dtype(const mlx_array arr) {
return mlx_array_dtype_ptr(arr);
}
// Forwards to mlx_array_eval_ptr and returns its status code unchanged.
int mlx_array_eval(mlx_array arr) {
    return (*mlx_array_eval_ptr)(arr);
}
int mlx_array_item_bool(bool* res, const mlx_array arr) {
return mlx_array_item_bool_ptr(res, arr);
}
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) {
return mlx_array_item_uint8_ptr(res, arr);
}
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) {
return mlx_array_item_uint16_ptr(res, arr);
}
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) {
return mlx_array_item_uint32_ptr(res, arr);
}
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) {
return mlx_array_item_uint64_ptr(res, arr);
}
int mlx_array_item_int8(int8_t* res, const mlx_array arr) {
return mlx_array_item_int8_ptr(res, arr);
}
int mlx_array_item_int16(int16_t* res, const mlx_array arr) {
return mlx_array_item_int16_ptr(res, arr);
}
int mlx_array_item_int32(int32_t* res, const mlx_array arr) {
return mlx_array_item_int32_ptr(res, arr);
}
int mlx_array_item_int64(int64_t* res, const mlx_array arr) {
return mlx_array_item_int64_ptr(res, arr);
}
int mlx_array_item_float32(float* res, const mlx_array arr) {
return mlx_array_item_float32_ptr(res, arr);
}
int mlx_array_item_float64(double* res, const mlx_array arr) {
return mlx_array_item_float64_ptr(res, arr);
}
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
return mlx_array_item_complex64_ptr(res, arr);
}
#if defined(__aarch64__) || defined(_M_ARM64)
// Only compiled for ARM64 targets (see the guard above).
// Forwards to mlx_array_item_float16_ptr and returns its status code unchanged.
int mlx_array_item_float16(float16_t* res, const mlx_array arr) {
    return (*mlx_array_item_float16_ptr)(res, arr);
}
#endif
#if defined(__aarch64__) || defined(_M_ARM64)
// Only compiled for ARM64 targets (see the guard above).
// Forwards to mlx_array_item_bfloat16_ptr and returns its status code unchanged.
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) {
    return (*mlx_array_item_bfloat16_ptr)(res, arr);
}
#endif
const bool* mlx_array_data_bool(const mlx_array arr) {
return mlx_array_data_bool_ptr(arr);
}
const uint8_t* mlx_array_data_uint8(const mlx_array arr) {
return mlx_array_data_uint8_ptr(arr);
}
const uint16_t* mlx_array_data_uint16(const mlx_array arr) {
return mlx_array_data_uint16_ptr(arr);
}
const uint32_t* mlx_array_data_uint32(const mlx_array arr) {
return mlx_array_data_uint32_ptr(arr);
}
const uint64_t* mlx_array_data_uint64(const mlx_array arr) {
return mlx_array_data_uint64_ptr(arr);
}
const int8_t* mlx_array_data_int8(const mlx_array arr) {
return mlx_array_data_int8_ptr(arr);
}
const int16_t* mlx_array_data_int16(const mlx_array arr) {
return mlx_array_data_int16_ptr(arr);
}
const int32_t* mlx_array_data_int32(const mlx_array arr) {
return mlx_array_data_int32_ptr(arr);
}
const int64_t* mlx_array_data_int64(const mlx_array arr) {
return mlx_array_data_int64_ptr(arr);
}
const float* mlx_array_data_float32(const mlx_array arr) {
return mlx_array_data_float32_ptr(arr);
}
const double* mlx_array_data_float64(const mlx_array arr) {
return mlx_array_data_float64_ptr(arr);
}
const float _Complex* mlx_array_data_complex64(const mlx_array arr) {
return mlx_array_data_complex64_ptr(arr);
}
#if defined(__aarch64__) || defined(_M_ARM64)
// Only compiled for ARM64 targets (see the guard above).
// Forwards to mlx_array_data_float16_ptr and returns the pointer it yields unchanged.
const float16_t* mlx_array_data_float16(const mlx_array arr) {
    return (*mlx_array_data_float16_ptr)(arr);
}
#endif
#if defined(__aarch64__) || defined(_M_ARM64)
// Only compiled for ARM64 targets (see the guard above).
// Forwards to mlx_array_data_bfloat16_ptr and returns the pointer it yields unchanged.
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr) {
    return (*mlx_array_data_bfloat16_ptr)(arr);
}
#endif
int _mlx_array_is_available(bool* res, const mlx_array arr) {
return _mlx_array_is_available_ptr(res, arr);
}
int _mlx_array_wait(const mlx_array arr) {
return _mlx_array_wait_ptr(arr);
}
int _mlx_array_is_contiguous(bool* res, const mlx_array arr) {
return _mlx_array_is_contiguous_ptr(res, arr);
}
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) {
return _mlx_array_is_row_contiguous_ptr(res, arr);
}
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr) {
return _mlx_array_is_col_contiguous_ptr(res, arr);
}
mlx_closure mlx_closure_new(void) {
return mlx_closure_new_ptr();
}
// Forwards to mlx_closure_free_ptr and returns its status code unchanged.
int mlx_closure_free(mlx_closure cls) {
    return (*mlx_closure_free_ptr)(cls);
}
// Forwards the callback to mlx_closure_new_func_ptr and returns the closure it produces.
mlx_closure mlx_closure_new_func(int (*fun)(mlx_vector_array*, const mlx_vector_array)) {
    return (*mlx_closure_new_func_ptr)(fun);
}
// Forwards callback, payload and destructor to mlx_closure_new_func_payload_ptr
// and returns the closure it produces.
mlx_closure mlx_closure_new_func_payload(
    int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
    void* payload,
    void (*dtor)(void*)) {
    return (*mlx_closure_new_func_payload_ptr)(fun, payload, dtor);
}
// Forwards to mlx_closure_set_ptr and returns its status code unchanged.
int mlx_closure_set(mlx_closure* cls, const mlx_closure src) {
    return (*mlx_closure_set_ptr)(cls, src);
}
// Forwards to mlx_closure_apply_ptr and returns its status code unchanged.
int mlx_closure_apply(mlx_vector_array* res, mlx_closure cls, const mlx_vector_array input) {
    return (*mlx_closure_apply_ptr)(res, cls, input);
}
// Forwards the callback to mlx_closure_new_unary_ptr and returns the closure it produces.
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) {
    return (*mlx_closure_new_unary_ptr)(fun);
}
mlx_closure_kwargs mlx_closure_kwargs_new(void) {
return mlx_closure_kwargs_new_ptr();
}
// Forwards to mlx_closure_kwargs_free_ptr and returns its status code unchanged.
int mlx_closure_kwargs_free(mlx_closure_kwargs cls) {
    return (*mlx_closure_kwargs_free_ptr)(cls);
}
// Forwards the callback to mlx_closure_kwargs_new_func_ptr and returns the closure it produces.
mlx_closure_kwargs mlx_closure_kwargs_new_func(
    int (*fun)(mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array)) {
    return (*mlx_closure_kwargs_new_func_ptr)(fun);
}
// Forwards callback, payload and destructor to mlx_closure_kwargs_new_func_payload_ptr
// and returns the closure it produces.
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
    int (*fun)(mlx_vector_array*, const mlx_vector_array, const mlx_map_string_to_array, void*),
    void* payload,
    void (*dtor)(void*)) {
    return (*mlx_closure_kwargs_new_func_payload_ptr)(fun, payload, dtor);
}
// Forwards to mlx_closure_kwargs_set_ptr and returns its status code unchanged.
int mlx_closure_kwargs_set(mlx_closure_kwargs* cls, const mlx_closure_kwargs src) {
    return (*mlx_closure_kwargs_set_ptr)(cls, src);
}
// Forwards to mlx_closure_kwargs_apply_ptr and returns its status code unchanged.
int mlx_closure_kwargs_apply(
    mlx_vector_array* res,
    mlx_closure_kwargs cls,
    const mlx_vector_array input_0,
    const mlx_map_string_to_array input_1) {
    return (*mlx_closure_kwargs_apply_ptr)(res, cls, input_0, input_1);
}
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) {
return mlx_closure_value_and_grad_new_ptr();
}
// Forwards to mlx_closure_value_and_grad_free_ptr and returns its status code unchanged.
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) {
    return (*mlx_closure_value_and_grad_free_ptr)(cls);
}
// Forwards the callback to mlx_closure_value_and_grad_new_func_ptr and returns
// the closure it produces.
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
    int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) {
    return (*mlx_closure_value_and_grad_new_func_ptr)(fun);
}
// Forwards callback, payload and destructor to
// mlx_closure_value_and_grad_new_func_payload_ptr and returns the closure it produces.
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
    int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array, void*),
    void* payload,
    void (*dtor)(void*)) {
    return (*mlx_closure_value_and_grad_new_func_payload_ptr)(fun, payload, dtor);
}
// Forwards to mlx_closure_value_and_grad_set_ptr and returns its status code unchanged.
int mlx_closure_value_and_grad_set(
    mlx_closure_value_and_grad* cls,
    const mlx_closure_value_and_grad src) {
    return (*mlx_closure_value_and_grad_set_ptr)(cls, src);
}
// Forwards to mlx_closure_value_and_grad_apply_ptr and returns its status code unchanged.
int mlx_closure_value_and_grad_apply(
    mlx_vector_array* res_0,
    mlx_vector_array* res_1,
    mlx_closure_value_and_grad cls,
    const mlx_vector_array input) {
    return (*mlx_closure_value_and_grad_apply_ptr)(res_0, res_1, cls, input);
}
mlx_closure_custom mlx_closure_custom_new(void) {
return mlx_closure_custom_new_ptr();
}
// Forwards to mlx_closure_custom_free_ptr and returns its status code unchanged.
int mlx_closure_custom_free(mlx_closure_custom cls) {
    return (*mlx_closure_custom_free_ptr)(cls);
}
// Forwards the callback to mlx_closure_custom_new_func_ptr and returns the closure it produces.
mlx_closure_custom mlx_closure_custom_new_func(
    int (*fun)(
        mlx_vector_array*,
        const mlx_vector_array,
        const mlx_vector_array,
        const mlx_vector_array)) {
    return (*mlx_closure_custom_new_func_ptr)(fun);
}
// Forwards callback, payload and destructor to mlx_closure_custom_new_func_payload_ptr
// and returns the closure it produces.
mlx_closure_custom mlx_closure_custom_new_func_payload(
    int (*fun)(
        mlx_vector_array*,
        const mlx_vector_array,
        const mlx_vector_array,
        const mlx_vector_array,
        void*),
    void* payload,
    void (*dtor)(void*)) {
    return (*mlx_closure_custom_new_func_payload_ptr)(fun, payload, dtor);
}
// Forwards to mlx_closure_custom_set_ptr and returns its status code unchanged.
int mlx_closure_custom_set(mlx_closure_custom* cls, const mlx_closure_custom src) {
    return (*mlx_closure_custom_set_ptr)(cls, src);
}
// Forwards to mlx_closure_custom_apply_ptr and returns its status code unchanged.
int mlx_closure_custom_apply(
    mlx_vector_array* res,
    mlx_closure_custom cls,
    const mlx_vector_array input_0,
    const mlx_vector_array input_1,
    const mlx_vector_array input_2) {
    return (*mlx_closure_custom_apply_ptr)(res, cls, input_0, input_1, input_2);
}
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) {
return mlx_closure_custom_jvp_new_ptr();
}
// Forwards to mlx_closure_custom_jvp_free_ptr and returns its status code unchanged.
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) {
    return (*mlx_closure_custom_jvp_free_ptr)(cls);
}
// Forwards the callback to mlx_closure_custom_jvp_new_func_ptr and returns the closure it produces.
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(
    int (*fun)(
        mlx_vector_array*,
        const mlx_vector_array,
        const mlx_vector_array,
        const int*,
        size_t _num)) {
    return (*mlx_closure_custom_jvp_new_func_ptr)(fun);
}
// Forwards callback, payload and destructor to mlx_closure_custom_jvp_new_func_payload_ptr
// and returns the closure it produces.
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
    int (*fun)(
        mlx_vector_array*,
        const mlx_vector_array,
        const mlx_vector_array,
        const int*,
        size_t _num,
        void*),
    void* payload,
    void (*dtor)(void*)) {
    return (*mlx_closure_custom_jvp_new_func_payload_ptr)(fun, payload, dtor);
}
// Forwards to mlx_closure_custom_jvp_set_ptr and returns its status code unchanged.
int mlx_closure_custom_jvp_set(mlx_closure_custom_jvp* cls, const mlx_closure_custom_jvp src) {
    return (*mlx_closure_custom_jvp_set_ptr)(cls, src);
}
// Forwards to mlx_closure_custom_jvp_apply_ptr and returns its status code unchanged.
int mlx_closure_custom_jvp_apply(
    mlx_vector_array* res,
    mlx_closure_custom_jvp cls,
    const mlx_vector_array input_0,
    const mlx_vector_array input_1,
    const int* input_2,
    size_t input_2_num) {
    return (*mlx_closure_custom_jvp_apply_ptr)(res, cls, input_0, input_1, input_2, input_2_num);
}
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) {
return mlx_closure_custom_vmap_new_ptr();
}
// Forwards to mlx_closure_custom_vmap_free_ptr and returns its status code unchanged.
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) {
    return (*mlx_closure_custom_vmap_free_ptr)(cls);
}
// Forwards the callback to mlx_closure_custom_vmap_new_func_ptr and returns the closure it produces.
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(
    int (*fun)(
        mlx_vector_array*,
        mlx_vector_int*,
        const mlx_vector_array,
        const int*,
        size_t _num)) {
    return (*mlx_closure_custom_vmap_new_func_ptr)(fun);
}
// Forwards callback, payload and destructor to mlx_closure_custom_vmap_new_func_payload_ptr
// and returns the closure it produces.
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
    int (*fun)(
        mlx_vector_array*,
        mlx_vector_int*,
        const mlx_vector_array,
        const int*,
        size_t _num,
        void*),
    void* payload,
    void (*dtor)(void*)) {
    return (*mlx_closure_custom_vmap_new_func_payload_ptr)(fun, payload, dtor);
}
// Forwards to mlx_closure_custom_vmap_set_ptr and returns its status code unchanged.
int mlx_closure_custom_vmap_set(mlx_closure_custom_vmap* cls, const mlx_closure_custom_vmap src) {
    return (*mlx_closure_custom_vmap_set_ptr)(cls, src);
}
// Forwards to mlx_closure_custom_vmap_apply_ptr and returns its status code unchanged.
int mlx_closure_custom_vmap_apply(
    mlx_vector_array* res_0,
    mlx_vector_int* res_1,
    mlx_closure_custom_vmap cls,
    const mlx_vector_array input_0,
    const int* input_1,
    size_t input_1_num) {
    return (*mlx_closure_custom_vmap_apply_ptr)(res_0, res_1, cls, input_0, input_1, input_1_num);
}
// Forwards to mlx_compile_ptr and returns its status code unchanged.
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) {
    return (*mlx_compile_ptr)(res, fun, shapeless);
}
// Forwards to mlx_detail_compile_ptr and returns its status code unchanged.
int mlx_detail_compile(
    mlx_closure* res,
    const mlx_closure fun,
    uintptr_t fun_id,
    bool shapeless,
    const uint64_t* constants,
    size_t constants_num) {
    return (*mlx_detail_compile_ptr)(res, fun, fun_id, shapeless, constants, constants_num);
}
int mlx_detail_compile_clear_cache(void) {
return mlx_detail_compile_clear_cache_ptr();
}
int mlx_detail_compile_erase(uintptr_t fun_id) {
return mlx_detail_compile_erase_ptr(fun_id);
}
int mlx_disable_compile(void) {
return mlx_disable_compile_ptr();
}
int mlx_enable_compile(void) {
return mlx_enable_compile_ptr();
}
// Forwards to mlx_set_compile_mode_ptr and returns its status code unchanged.
int mlx_set_compile_mode(mlx_compile_mode mode) {
    return (*mlx_set_compile_mode_ptr)(mode);
}
mlx_device mlx_device_new(void) {
return mlx_device_new_ptr();
}
// Forwards to mlx_device_new_type_ptr and returns the device it produces.
mlx_device mlx_device_new_type(mlx_device_type type, int index) {
    return (*mlx_device_new_type_ptr)(type, index);
}
// Forwards to mlx_device_free_ptr and returns its status code unchanged.
int mlx_device_free(mlx_device dev) {
    return (*mlx_device_free_ptr)(dev);
}
// Forwards to mlx_device_set_ptr and returns its status code unchanged.
int mlx_device_set(mlx_device* dev, const mlx_device src) {
    return (*mlx_device_set_ptr)(dev, src);
}
// Forwards to mlx_device_tostring_ptr and returns its status code unchanged.
int mlx_device_tostring(mlx_string* str, mlx_device dev) {
    return (*mlx_device_tostring_ptr)(str, dev);
}
// Forwards to mlx_device_equal_ptr and returns its boolean result unchanged.
bool mlx_device_equal(mlx_device lhs, mlx_device rhs) {
    return (*mlx_device_equal_ptr)(lhs, rhs);
}
// Forwards to mlx_device_get_index_ptr and returns its status code unchanged.
int mlx_device_get_index(int* index, mlx_device dev) {
    return (*mlx_device_get_index_ptr)(index, dev);
}
// Forwards to mlx_device_get_type_ptr and returns its status code unchanged.
int mlx_device_get_type(mlx_device_type* type, mlx_device dev) {
    return (*mlx_device_get_type_ptr)(type, dev);
}
// Forwards to mlx_get_default_device_ptr and returns its status code unchanged.
int mlx_get_default_device(mlx_device* dev) {
    return (*mlx_get_default_device_ptr)(dev);
}
// Forwards to mlx_set_default_device_ptr and returns its status code unchanged.
int mlx_set_default_device(mlx_device dev) {
    return (*mlx_set_default_device_ptr)(dev);
}
// Forwards to mlx_distributed_all_gather_ptr and returns its status code unchanged.
int mlx_distributed_all_gather(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group,
    const mlx_stream S) {
    return (*mlx_distributed_all_gather_ptr)(res, x, group, S);
}
// Forwards to mlx_distributed_all_max_ptr and returns its status code unchanged.
int mlx_distributed_all_max(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group,
    const mlx_stream s) {
    return (*mlx_distributed_all_max_ptr)(res, x, group, s);
}
// Forwards to mlx_distributed_all_min_ptr and returns its status code unchanged.
int mlx_distributed_all_min(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group,
    const mlx_stream s) {
    return (*mlx_distributed_all_min_ptr)(res, x, group, s);
}
// Forwards to mlx_distributed_all_sum_ptr and returns its status code unchanged.
int mlx_distributed_all_sum(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group,
    const mlx_stream s) {
    return (*mlx_distributed_all_sum_ptr)(res, x, group, s);
}
// Forwards to mlx_distributed_recv_ptr and returns its status code unchanged.
int mlx_distributed_recv(
    mlx_array* res,
    const int* shape,
    size_t shape_num,
    mlx_dtype dtype,
    int src,
    const mlx_distributed_group group,
    const mlx_stream s) {
    return (*mlx_distributed_recv_ptr)(res, shape, shape_num, dtype, src, group, s);
}
// Forwards to mlx_distributed_recv_like_ptr and returns its status code unchanged.
int mlx_distributed_recv_like(
    mlx_array* res,
    const mlx_array x,
    int src,
    const mlx_distributed_group group,
    const mlx_stream s) {
    return (*mlx_distributed_recv_like_ptr)(res, x, src, group, s);
}
// Forwards to mlx_distributed_send_ptr and returns its status code unchanged.
int mlx_distributed_send(
    mlx_array* res,
    const mlx_array x,
    int dst,
    const mlx_distributed_group group,
    const mlx_stream s) {
    return (*mlx_distributed_send_ptr)(res, x, dst, group, s);
}
// Forwards to mlx_distributed_sum_scatter_ptr and returns its status code unchanged.
int mlx_distributed_sum_scatter(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group,
    const mlx_stream s) {
    return (*mlx_distributed_sum_scatter_ptr)(res, x, group, s);
}
// Forwards to mlx_distributed_group_rank_ptr and returns its result unchanged.
int mlx_distributed_group_rank(mlx_distributed_group group) {
    return (*mlx_distributed_group_rank_ptr)(group);
}
// Forwards to mlx_distributed_group_size_ptr and returns its result unchanged.
int mlx_distributed_group_size(mlx_distributed_group group) {
    return (*mlx_distributed_group_size_ptr)(group);
}
// Forwards to mlx_distributed_group_split_ptr and returns the group it produces.
mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
    return (*mlx_distributed_group_split_ptr)(group, color, key);
}
bool mlx_distributed_is_available(void) {
return mlx_distributed_is_available_ptr();
}
mlx_distributed_group mlx_distributed_init(bool strict) {
return mlx_distributed_init_ptr(strict);
}
// Forwards handler, user data and destructor to mlx_set_error_handler_ptr; nothing is returned.
void mlx_set_error_handler(
    mlx_error_handler_func handler,
    void* data,
    void (*dtor)(void*)) {
    (*mlx_set_error_handler_ptr)(handler, data, dtor);
}
#include <stdarg.h>  // va_list, va_start, va_end (needed only by _mlx_error below)

// Variadic wrapper around _mlx_error_ptr.
//
// A va_list cannot be re-expanded into a "..." call, so simply passing `fmt`
// on (as this wrapper used to) dropped every variadic argument and left the
// callee reading arguments that were never supplied. Instead the message is
// formatted here and handed over as a single "%s" argument, which also keeps
// any '%' in the finished text from being interpreted a second time.
//
// file, line: passed through unchanged.
// fmt, ...:   printf-style format and its arguments.
//
// Messages longer than the local buffer are truncated.
// NOTE(review): assumes _mlx_error_ptr is declared with a trailing "..."
// like this wrapper - confirm against the generated pointer declaration.
void _mlx_error(const char* file, const int line, const char* fmt, ...) {
    enum { MLX_ERROR_MSG_CAP = 1024 };
    char message[MLX_ERROR_MSG_CAP];

    if (fmt == NULL) {
        // Nothing to format; keep the previous pass-through behaviour.
        _mlx_error_ptr(file, line, fmt);
        return;
    }

    va_list args;
    va_start(args, fmt);
    int written = vsnprintf(message, sizeof message, fmt, args);
    va_end(args);

    if (written < 0) {
        // vsnprintf reported an encoding error; still deliver something readable.
        snprintf(message, sizeof message, "%s", "(error message could not be formatted)");
    }

    _mlx_error_ptr(file, line, "%s", message);
}
int mlx_export_function(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless) {
return mlx_export_function_ptr(file, fun, args, shapeless);
}
int mlx_export_function_kwargs(const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args, const mlx_map_string_to_array kwargs, bool shapeless) {
return mlx_export_function_kwargs_ptr(file, fun, args, kwargs, shapeless);
}
mlx_function_exporter mlx_function_exporter_new(const char* file, const mlx_closure fun, bool shapeless) {
return mlx_function_exporter_new_ptr(file, fun, shapeless);
}
// Forwards to mlx_function_exporter_free_ptr and returns its status code unchanged.
int mlx_function_exporter_free(mlx_function_exporter xfunc) {
    return (*mlx_function_exporter_free_ptr)(xfunc);
}
int mlx_function_exporter_apply(const mlx_function_exporter xfunc, const mlx_vector_array args) {
return mlx_function_exporter_apply_ptr(xfunc, args);
}
int mlx_function_exporter_apply_kwargs(const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs) {
return mlx_function_exporter_apply_kwargs_ptr(xfunc, args, kwargs);
}
mlx_imported_function mlx_imported_function_new(const char* file) {
return mlx_imported_function_new_ptr(file);
}
// Forwards to mlx_imported_function_free_ptr and returns its status code unchanged.
int mlx_imported_function_free(mlx_imported_function xfunc) {
    return (*mlx_imported_function_free_ptr)(xfunc);
}
// Forwards to mlx_imported_function_apply_ptr and returns its status code unchanged.
int mlx_imported_function_apply(
    mlx_vector_array* res,
    const mlx_imported_function xfunc,
    const mlx_vector_array args) {
    return (*mlx_imported_function_apply_ptr)(res, xfunc, args);
}
// Forwards to mlx_imported_function_apply_kwargs_ptr and returns its status code unchanged.
int mlx_imported_function_apply_kwargs(
    mlx_vector_array* res,
    const mlx_imported_function xfunc,
    const mlx_vector_array args,
    const mlx_map_string_to_array kwargs) {
    return (*mlx_imported_function_apply_kwargs_ptr)(res, xfunc, args, kwargs);
}
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) {
return mlx_fast_cuda_kernel_config_new_ptr();
}
// Forwards to mlx_fast_cuda_kernel_config_free_ptr; nothing is returned.
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) {
    (*mlx_fast_cuda_kernel_config_free_ptr)(cls);
}
// Forwards to mlx_fast_cuda_kernel_config_add_output_arg_ptr and returns its status code unchanged.
int mlx_fast_cuda_kernel_config_add_output_arg(
    mlx_fast_cuda_kernel_config cls,
    const int* shape,
    size_t size,
    mlx_dtype dtype) {
    return (*mlx_fast_cuda_kernel_config_add_output_arg_ptr)(cls, shape, size, dtype);
}
// Forwards to mlx_fast_cuda_kernel_config_set_grid_ptr and returns its status code unchanged.
int mlx_fast_cuda_kernel_config_set_grid(
    mlx_fast_cuda_kernel_config cls,
    int grid1,
    int grid2,
    int grid3) {
    return (*mlx_fast_cuda_kernel_config_set_grid_ptr)(cls, grid1, grid2, grid3);
}
// Forwards to mlx_fast_cuda_kernel_config_set_thread_group_ptr and returns its status code unchanged.
int mlx_fast_cuda_kernel_config_set_thread_group(
    mlx_fast_cuda_kernel_config cls,
    int thread1,
    int thread2,
    int thread3) {
    return (*mlx_fast_cuda_kernel_config_set_thread_group_ptr)(cls, thread1, thread2, thread3);
}
// Forwards to mlx_fast_cuda_kernel_config_set_init_value_ptr and returns its status code unchanged.
int mlx_fast_cuda_kernel_config_set_init_value(mlx_fast_cuda_kernel_config cls, float value) {
    return (*mlx_fast_cuda_kernel_config_set_init_value_ptr)(cls, value);
}
// Forwards to mlx_fast_cuda_kernel_config_set_verbose_ptr and returns its status code unchanged.
int mlx_fast_cuda_kernel_config_set_verbose(mlx_fast_cuda_kernel_config cls, bool verbose) {
    return (*mlx_fast_cuda_kernel_config_set_verbose_ptr)(cls, verbose);
}
// Forwards to mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr and returns
// its status code unchanged.
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
    mlx_fast_cuda_kernel_config cls,
    const char* name,
    mlx_dtype dtype) {
    return (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_ptr)(cls, name, dtype);
}
// Forwards to mlx_fast_cuda_kernel_config_add_template_arg_int_ptr and returns
// its status code unchanged.
int mlx_fast_cuda_kernel_config_add_template_arg_int(
    mlx_fast_cuda_kernel_config cls,
    const char* name,
    int value) {
    return (*mlx_fast_cuda_kernel_config_add_template_arg_int_ptr)(cls, name, value);
}
// Forwards to mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr and returns
// its status code unchanged.
int mlx_fast_cuda_kernel_config_add_template_arg_bool(
    mlx_fast_cuda_kernel_config cls,
    const char* name,
    bool value) {
    return (*mlx_fast_cuda_kernel_config_add_template_arg_bool_ptr)(cls, name, value);
}
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, int shared_memory) {
return mlx_fast_cuda_kernel_new_ptr(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory);
}
// Forwards to mlx_fast_cuda_kernel_free_ptr; nothing is returned.
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) {
    (*mlx_fast_cuda_kernel_free_ptr)(cls);
}
// Forwards to mlx_fast_cuda_kernel_apply_ptr and returns its status code unchanged.
int mlx_fast_cuda_kernel_apply(
    mlx_vector_array* outputs,
    mlx_fast_cuda_kernel cls,
    const mlx_vector_array inputs,
    const mlx_fast_cuda_kernel_config config,
    const mlx_stream stream) {
    return (*mlx_fast_cuda_kernel_apply_ptr)(outputs, cls, inputs, config, stream);
}
// Forwards to mlx_fast_layer_norm_ptr and returns its status code unchanged.
int mlx_fast_layer_norm(
    mlx_array* res,
    const mlx_array x,
    const mlx_array weight,
    const mlx_array bias,
    float eps,
    const mlx_stream s) {
    return (*mlx_fast_layer_norm_ptr)(res, x, weight, bias, eps, s);
}
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) {
return mlx_fast_metal_kernel_config_new_ptr();
}
// Forwards to mlx_fast_metal_kernel_config_free_ptr; nothing is returned.
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) {
    (*mlx_fast_metal_kernel_config_free_ptr)(cls);
}
// Forwards to mlx_fast_metal_kernel_config_add_output_arg_ptr and returns its status code unchanged.
int mlx_fast_metal_kernel_config_add_output_arg(
    mlx_fast_metal_kernel_config cls,
    const int* shape,
    size_t size,
    mlx_dtype dtype) {
    return (*mlx_fast_metal_kernel_config_add_output_arg_ptr)(cls, shape, size, dtype);
}
// Forwards to mlx_fast_metal_kernel_config_set_grid_ptr and returns its status code unchanged.
int mlx_fast_metal_kernel_config_set_grid(
    mlx_fast_metal_kernel_config cls,
    int grid1,
    int grid2,
    int grid3) {
    return (*mlx_fast_metal_kernel_config_set_grid_ptr)(cls, grid1, grid2, grid3);
}
// Forwards to mlx_fast_metal_kernel_config_set_thread_group_ptr and returns its result.
int mlx_fast_metal_kernel_config_set_thread_group(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) {
    const int rc = mlx_fast_metal_kernel_config_set_thread_group_ptr(cls, thread1, thread2, thread3);
    return rc;
}
// Forwards to mlx_fast_metal_kernel_config_set_init_value_ptr and returns its result.
int mlx_fast_metal_kernel_config_set_init_value(mlx_fast_metal_kernel_config cls, float value) {
    const int rc = mlx_fast_metal_kernel_config_set_init_value_ptr(cls, value);
    return rc;
}
// Forwards to mlx_fast_metal_kernel_config_set_verbose_ptr and returns its result.
int mlx_fast_metal_kernel_config_set_verbose(mlx_fast_metal_kernel_config cls, bool verbose) {
    const int rc = mlx_fast_metal_kernel_config_set_verbose_ptr(cls, verbose);
    return rc;
}
// Forwards to mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr and returns its result.
int mlx_fast_metal_kernel_config_add_template_arg_dtype(mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype) {
    const int rc = mlx_fast_metal_kernel_config_add_template_arg_dtype_ptr(cls, name, dtype);
    return rc;
}
// Forwards to mlx_fast_metal_kernel_config_add_template_arg_int_ptr and returns its result.
int mlx_fast_metal_kernel_config_add_template_arg_int(mlx_fast_metal_kernel_config cls, const char* name, int value) {
    const int rc = mlx_fast_metal_kernel_config_add_template_arg_int_ptr(cls, name, value);
    return rc;
}
// Forwards to mlx_fast_metal_kernel_config_add_template_arg_bool_ptr and returns its result.
int mlx_fast_metal_kernel_config_add_template_arg_bool(mlx_fast_metal_kernel_config cls, const char* name, bool value) {
    const int rc = mlx_fast_metal_kernel_config_add_template_arg_bool_ptr(cls, name, value);
    return rc;
}
// Forwards to mlx_fast_metal_kernel_new_ptr and returns its result.
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names, const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs) {
    return mlx_fast_metal_kernel_new_ptr(
        name,
        input_names,
        output_names,
        source,
        header,
        ensure_row_contiguous,
        atomic_outputs);
}
// Forwards to mlx_fast_metal_kernel_free_ptr; nothing is returned.
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) {
    mlx_fast_metal_kernel_free_ptr(cls);
    return;
}
// Forwards to mlx_fast_metal_kernel_apply_ptr and returns its result.
int mlx_fast_metal_kernel_apply(mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs, const mlx_fast_metal_kernel_config config, const mlx_stream stream) {
    const int rc = mlx_fast_metal_kernel_apply_ptr(outputs, cls, inputs, config, stream);
    return rc;
}
// Forwards to mlx_fast_rms_norm_ptr and returns its result.
int mlx_fast_rms_norm(mlx_array* res, const mlx_array x, const mlx_array weight, float eps, const mlx_stream s) {
    const int rc = mlx_fast_rms_norm_ptr(res, x, weight, eps, s);
    return rc;
}
// Forwards to mlx_fast_rope_ptr and returns its result.
int mlx_fast_rope(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, int offset, const mlx_array freqs, const mlx_stream s) {
    const int rc = mlx_fast_rope_ptr(res, x, dims, traditional, base, scale, offset, freqs, s);
    return rc;
}
// Forwards to mlx_fast_rope_dynamic_ptr and returns its result.
int mlx_fast_rope_dynamic(mlx_array* res, const mlx_array x, int dims, bool traditional, mlx_optional_float base, float scale, const mlx_array offset, const mlx_array freqs, const mlx_stream s) {
    const int rc = mlx_fast_rope_dynamic_ptr(res, x, dims, traditional, base, scale, offset, freqs, s);
    return rc;
}
// Forwards to mlx_fast_scaled_dot_product_attention_ptr and returns its result.
int mlx_fast_scaled_dot_product_attention(mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values, float scale, const char* mask_mode, const mlx_array mask_arr, const mlx_array sinks, const mlx_stream s) {
    const int rc = mlx_fast_scaled_dot_product_attention_ptr(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s);
    return rc;
}
// Forwards to mlx_fft_fft_ptr and returns its result.
int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) {
    const int rc = mlx_fft_fft_ptr(res, a, n, axis, s);
    return rc;
}
// Forwards to mlx_fft_fft2_ptr and returns its result.
int mlx_fft_fft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_fft2_ptr(res, a, n, n_num, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_fftn_ptr and returns its result.
int mlx_fft_fftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_fftn_ptr(res, a, n, n_num, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_fftshift_ptr and returns its result.
int mlx_fft_fftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_fftshift_ptr(res, a, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_ifft_ptr and returns its result.
int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) {
    const int rc = mlx_fft_ifft_ptr(res, a, n, axis, s);
    return rc;
}
// Forwards to mlx_fft_ifft2_ptr and returns its result.
int mlx_fft_ifft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_ifft2_ptr(res, a, n, n_num, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_ifftn_ptr and returns its result.
int mlx_fft_ifftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_ifftn_ptr(res, a, n, n_num, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_ifftshift_ptr and returns its result.
int mlx_fft_ifftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_ifftshift_ptr(res, a, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_irfft_ptr and returns its result.
int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) {
    const int rc = mlx_fft_irfft_ptr(res, a, n, axis, s);
    return rc;
}
// Forwards to mlx_fft_irfft2_ptr and returns its result.
int mlx_fft_irfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_irfft2_ptr(res, a, n, n_num, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_irfftn_ptr and returns its result.
int mlx_fft_irfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_irfftn_ptr(res, a, n, n_num, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_rfft_ptr and returns its result.
int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s) {
    const int rc = mlx_fft_rfft_ptr(res, a, n, axis, s);
    return rc;
}
// Forwards to mlx_fft_rfft2_ptr and returns its result.
int mlx_fft_rfft2(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_rfft2_ptr(res, a, n, n_num, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_fft_rfftn_ptr and returns its result.
int mlx_fft_rfftn(mlx_array* res, const mlx_array a, const int* n, size_t n_num, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_fft_rfftn_ptr(res, a, n, n_num, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_load_reader_ptr and returns its result.
int mlx_load_reader(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s) {
    const int rc = mlx_load_reader_ptr(res, in_stream, s);
    return rc;
}
// Forwards to mlx_load_ptr and returns its result.
int mlx_load(mlx_array* res, const char* file, const mlx_stream s) {
    const int rc = mlx_load_ptr(res, file, s);
    return rc;
}
// Forwards to mlx_load_safetensors_reader_ptr and returns its result.
int mlx_load_safetensors_reader(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, mlx_io_reader in_stream, const mlx_stream s) {
    const int rc = mlx_load_safetensors_reader_ptr(res_0, res_1, in_stream, s);
    return rc;
}
// Forwards to mlx_load_safetensors_ptr and returns its result.
int mlx_load_safetensors(mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1, const char* file, const mlx_stream s) {
    const int rc = mlx_load_safetensors_ptr(res_0, res_1, file, s);
    return rc;
}
// Forwards to mlx_save_writer_ptr and returns its result.
int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) {
    const int rc = mlx_save_writer_ptr(out_stream, a);
    return rc;
}
// Forwards to mlx_save_ptr and returns its result.
int mlx_save(const char* file, const mlx_array a) {
    const int rc = mlx_save_ptr(file, a);
    return rc;
}
// Forwards to mlx_save_safetensors_writer_ptr and returns its result.
int mlx_save_safetensors_writer(mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) {
    const int rc = mlx_save_safetensors_writer_ptr(in_stream, param, metadata);
    return rc;
}
// Forwards to mlx_save_safetensors_ptr and returns its result.
int mlx_save_safetensors(const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata) {
    const int rc = mlx_save_safetensors_ptr(file, param, metadata);
    return rc;
}
// Forwards to mlx_io_reader_new_ptr and returns its result.
mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) {
    return mlx_io_reader_new_ptr(
        desc,
        vtable);
}
// Forwards to mlx_io_reader_descriptor_ptr and returns its result.
int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) {
    const int rc = mlx_io_reader_descriptor_ptr(desc_, io);
    return rc;
}
// Forwards to mlx_io_reader_tostring_ptr and returns its result.
int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) {
    const int rc = mlx_io_reader_tostring_ptr(str_, io);
    return rc;
}
// Forwards to mlx_io_reader_free_ptr and returns its result.
int mlx_io_reader_free(mlx_io_reader io) {
    const int rc = mlx_io_reader_free_ptr(io);
    return rc;
}
// Forwards to mlx_io_writer_new_ptr and returns its result.
mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) {
    return mlx_io_writer_new_ptr(
        desc,
        vtable);
}
// Forwards to mlx_io_writer_descriptor_ptr and returns its result.
int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) {
    const int rc = mlx_io_writer_descriptor_ptr(desc_, io);
    return rc;
}
// Forwards to mlx_io_writer_tostring_ptr and returns its result.
int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) {
    const int rc = mlx_io_writer_tostring_ptr(str_, io);
    return rc;
}
// Forwards to mlx_io_writer_free_ptr and returns its result.
int mlx_io_writer_free(mlx_io_writer io) {
    const int rc = mlx_io_writer_free_ptr(io);
    return rc;
}
// Forwards to mlx_linalg_cholesky_ptr and returns its result.
int mlx_linalg_cholesky(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) {
    const int rc = mlx_linalg_cholesky_ptr(res, a, upper, s);
    return rc;
}
// Forwards to mlx_linalg_cholesky_inv_ptr and returns its result.
int mlx_linalg_cholesky_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) {
    const int rc = mlx_linalg_cholesky_inv_ptr(res, a, upper, s);
    return rc;
}
// Forwards to mlx_linalg_cross_ptr and returns its result.
int mlx_linalg_cross(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) {
    const int rc = mlx_linalg_cross_ptr(res, a, b, axis, s);
    return rc;
}
// Forwards to mlx_linalg_eig_ptr and returns its result.
int mlx_linalg_eig(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_linalg_eig_ptr(res_0, res_1, a, s);
    return rc;
}
// Forwards to mlx_linalg_eigh_ptr and returns its result.
int mlx_linalg_eigh(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s) {
    const int rc = mlx_linalg_eigh_ptr(res_0, res_1, a, UPLO, s);
    return rc;
}
// Forwards to mlx_linalg_eigvals_ptr and returns its result.
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_linalg_eigvals_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_linalg_eigvalsh_ptr and returns its result.
int mlx_linalg_eigvalsh(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s) {
    const int rc = mlx_linalg_eigvalsh_ptr(res, a, UPLO, s);
    return rc;
}
// Forwards to mlx_linalg_inv_ptr and returns its result.
int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_linalg_inv_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_linalg_lu_ptr and returns its result.
int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_linalg_lu_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_linalg_lu_factor_ptr and returns its result.
int mlx_linalg_lu_factor(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_linalg_lu_factor_ptr(res_0, res_1, a, s);
    return rc;
}
// Forwards to mlx_linalg_norm_ptr and returns its result.
int mlx_linalg_norm(mlx_array* res, const mlx_array a, double ord, const int* axis, size_t axis_num, bool keepdims, const mlx_stream s) {
    const int rc = mlx_linalg_norm_ptr(res, a, ord, axis, axis_num, keepdims, s);
    return rc;
}
// Forwards to mlx_linalg_norm_matrix_ptr and returns its result.
int mlx_linalg_norm_matrix(mlx_array* res, const mlx_array a, const char* ord, const int* axis, size_t axis_num, bool keepdims, const mlx_stream s) {
    const int rc = mlx_linalg_norm_matrix_ptr(res, a, ord, axis, axis_num, keepdims, s);
    return rc;
}
// Forwards to mlx_linalg_norm_l2_ptr and returns its result.
int mlx_linalg_norm_l2(mlx_array* res, const mlx_array a, const int* axis, size_t axis_num, bool keepdims, const mlx_stream s) {
    const int rc = mlx_linalg_norm_l2_ptr(res, a, axis, axis_num, keepdims, s);
    return rc;
}
// Forwards to mlx_linalg_pinv_ptr and returns its result.
int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_linalg_pinv_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_linalg_qr_ptr and returns its result.
int mlx_linalg_qr(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_linalg_qr_ptr(res_0, res_1, a, s);
    return rc;
}
// Forwards to mlx_linalg_solve_ptr and returns its result.
int mlx_linalg_solve(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_linalg_solve_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_linalg_solve_triangular_ptr and returns its result.
int mlx_linalg_solve_triangular(mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s) {
    const int rc = mlx_linalg_solve_triangular_ptr(res, a, b, upper, s);
    return rc;
}
// Forwards to mlx_linalg_svd_ptr and returns its result.
int mlx_linalg_svd(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s) {
    const int rc = mlx_linalg_svd_ptr(res, a, compute_uv, s);
    return rc;
}
// Forwards to mlx_linalg_tri_inv_ptr and returns its result.
int mlx_linalg_tri_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s) {
    const int rc = mlx_linalg_tri_inv_ptr(res, a, upper, s);
    return rc;
}
// Forwards to mlx_map_string_to_array_new_ptr and returns its result.
mlx_map_string_to_array mlx_map_string_to_array_new(void)
{
    return mlx_map_string_to_array_new_ptr();
}
// Forwards to mlx_map_string_to_array_set_ptr and returns its result.
int mlx_map_string_to_array_set(mlx_map_string_to_array* map, const mlx_map_string_to_array src) {
    const int rc = mlx_map_string_to_array_set_ptr(map, src);
    return rc;
}
// Forwards to mlx_map_string_to_array_free_ptr and returns its result.
int mlx_map_string_to_array_free(mlx_map_string_to_array map) {
    const int rc = mlx_map_string_to_array_free_ptr(map);
    return rc;
}
// Forwards to mlx_map_string_to_array_insert_ptr and returns its result.
int mlx_map_string_to_array_insert(mlx_map_string_to_array map, const char* key, const mlx_array value) {
    const int rc = mlx_map_string_to_array_insert_ptr(map, key, value);
    return rc;
}
// Forwards to mlx_map_string_to_array_get_ptr and returns its result.
int mlx_map_string_to_array_get(mlx_array* value, const mlx_map_string_to_array map, const char* key) {
    const int rc = mlx_map_string_to_array_get_ptr(value, map, key);
    return rc;
}
// Forwards to mlx_map_string_to_array_iterator_new_ptr and returns its result.
mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(mlx_map_string_to_array map)
{
    return mlx_map_string_to_array_iterator_new_ptr(map);
}
// Forwards to mlx_map_string_to_array_iterator_free_ptr and returns its result.
int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) {
    const int rc = mlx_map_string_to_array_iterator_free_ptr(it);
    return rc;
}
// Forwards to mlx_map_string_to_array_iterator_next_ptr and returns its result.
int mlx_map_string_to_array_iterator_next(const char** key, mlx_array* value, mlx_map_string_to_array_iterator it) {
    const int rc = mlx_map_string_to_array_iterator_next_ptr(key, value, it);
    return rc;
}
// Forwards to mlx_map_string_to_string_new_ptr and returns its result.
mlx_map_string_to_string mlx_map_string_to_string_new(void)
{
    return mlx_map_string_to_string_new_ptr();
}
// Forwards to mlx_map_string_to_string_set_ptr and returns its result.
int mlx_map_string_to_string_set(mlx_map_string_to_string* map, const mlx_map_string_to_string src) {
    const int rc = mlx_map_string_to_string_set_ptr(map, src);
    return rc;
}
// Forwards to mlx_map_string_to_string_free_ptr and returns its result.
int mlx_map_string_to_string_free(mlx_map_string_to_string map) {
    const int rc = mlx_map_string_to_string_free_ptr(map);
    return rc;
}
// Forwards to mlx_map_string_to_string_insert_ptr and returns its result.
int mlx_map_string_to_string_insert(mlx_map_string_to_string map, const char* key, const char* value) {
    const int rc = mlx_map_string_to_string_insert_ptr(map, key, value);
    return rc;
}
// Forwards to mlx_map_string_to_string_get_ptr and returns its result.
int mlx_map_string_to_string_get(const char** value, const mlx_map_string_to_string map, const char* key) {
    const int rc = mlx_map_string_to_string_get_ptr(value, map, key);
    return rc;
}
// Forwards to mlx_map_string_to_string_iterator_new_ptr and returns its result.
mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(mlx_map_string_to_string map)
{
    return mlx_map_string_to_string_iterator_new_ptr(map);
}
// Forwards to mlx_map_string_to_string_iterator_free_ptr and returns its result.
int mlx_map_string_to_string_iterator_free(mlx_map_string_to_string_iterator it) {
    const int rc = mlx_map_string_to_string_iterator_free_ptr(it);
    return rc;
}
// Forwards to mlx_map_string_to_string_iterator_next_ptr and returns its result.
int mlx_map_string_to_string_iterator_next(const char** key, const char** value, mlx_map_string_to_string_iterator it) {
    const int rc = mlx_map_string_to_string_iterator_next_ptr(key, value, it);
    return rc;
}
// Forwards to mlx_clear_cache_ptr and returns its result.
int mlx_clear_cache(void) {
    const int rc = mlx_clear_cache_ptr();
    return rc;
}
// Forwards to mlx_get_active_memory_ptr and returns its result.
int mlx_get_active_memory(size_t* res) {
    const int rc = mlx_get_active_memory_ptr(res);
    return rc;
}
// Forwards to mlx_get_cache_memory_ptr and returns its result.
int mlx_get_cache_memory(size_t* res) {
    const int rc = mlx_get_cache_memory_ptr(res);
    return rc;
}
// Forwards to mlx_get_memory_limit_ptr and returns its result.
int mlx_get_memory_limit(size_t* res) {
    const int rc = mlx_get_memory_limit_ptr(res);
    return rc;
}
// Forwards to mlx_get_peak_memory_ptr and returns its result.
int mlx_get_peak_memory(size_t* res) {
    const int rc = mlx_get_peak_memory_ptr(res);
    return rc;
}
// Forwards to mlx_reset_peak_memory_ptr and returns its result.
int mlx_reset_peak_memory(void) {
    const int rc = mlx_reset_peak_memory_ptr();
    return rc;
}
// Forwards to mlx_set_cache_limit_ptr and returns its result.
int mlx_set_cache_limit(size_t* res, size_t limit) {
    const int rc = mlx_set_cache_limit_ptr(res, limit);
    return rc;
}
// Forwards to mlx_set_memory_limit_ptr and returns its result.
int mlx_set_memory_limit(size_t* res, size_t limit) {
    const int rc = mlx_set_memory_limit_ptr(res, limit);
    return rc;
}
// Forwards to mlx_set_wired_limit_ptr and returns its result.
int mlx_set_wired_limit(size_t* res, size_t limit) {
    const int rc = mlx_set_wired_limit_ptr(res, limit);
    return rc;
}
// Forwards to mlx_metal_device_info_ptr and returns its result.
mlx_metal_device_info_t mlx_metal_device_info(void)
{
    return mlx_metal_device_info_ptr();
}
// Forwards to mlx_metal_is_available_ptr and returns its result.
int mlx_metal_is_available(bool* res) {
    const int rc = mlx_metal_is_available_ptr(res);
    return rc;
}
// Forwards to mlx_metal_start_capture_ptr and returns its result.
int mlx_metal_start_capture(const char* path) {
    const int rc = mlx_metal_start_capture_ptr(path);
    return rc;
}
// Forwards to mlx_metal_stop_capture_ptr and returns its result.
int mlx_metal_stop_capture(void) {
    const int rc = mlx_metal_stop_capture_ptr();
    return rc;
}
// Forwards to mlx_abs_ptr and returns its result.
int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_abs_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_add_ptr and returns its result.
int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_add_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_addmm_ptr and returns its result.
int mlx_addmm(mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b, float alpha, float beta, const mlx_stream s) {
    const int rc = mlx_addmm_ptr(res, c, a, b, alpha, beta, s);
    return rc;
}
// Forwards to mlx_all_axes_ptr and returns its result.
int mlx_all_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    const int rc = mlx_all_axes_ptr(res, a, axes, axes_num, keepdims, s);
    return rc;
}
// Forwards to mlx_all_axis_ptr and returns its result.
int mlx_all_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    const int rc = mlx_all_axis_ptr(res, a, axis, keepdims, s);
    return rc;
}
// Forwards to mlx_all_ptr and returns its result.
int mlx_all(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    const int rc = mlx_all_ptr(res, a, keepdims, s);
    return rc;
}
// Forwards to mlx_allclose_ptr and returns its result.
int mlx_allclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) {
    const int rc = mlx_allclose_ptr(res, a, b, rtol, atol, equal_nan, s);
    return rc;
}
// Forwards to mlx_any_axes_ptr and returns its result.
int mlx_any_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    const int rc = mlx_any_axes_ptr(res, a, axes, axes_num, keepdims, s);
    return rc;
}
// Forwards to mlx_any_axis_ptr and returns its result.
int mlx_any_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    const int rc = mlx_any_axis_ptr(res, a, axis, keepdims, s);
    return rc;
}
// Forwards to mlx_any_ptr and returns its result.
int mlx_any(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    const int rc = mlx_any_ptr(res, a, keepdims, s);
    return rc;
}
// Forwards to mlx_arange_ptr and returns its result.
int mlx_arange(mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s) {
    const int rc = mlx_arange_ptr(res, start, stop, step, dtype, s);
    return rc;
}
// Forwards to mlx_arccos_ptr and returns its result.
int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_arccos_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_arccosh_ptr and returns its result.
int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_arccosh_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_arcsin_ptr and returns its result.
int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_arcsin_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_arcsinh_ptr and returns its result.
int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_arcsinh_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_arctan_ptr and returns its result.
int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_arctan_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_arctan2_ptr and returns its result.
int mlx_arctan2(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_arctan2_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_arctanh_ptr and returns its result.
int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_arctanh_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_argmax_axis_ptr and returns its result.
int mlx_argmax_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    const int rc = mlx_argmax_axis_ptr(res, a, axis, keepdims, s);
    return rc;
}
// Forwards to mlx_argmax_ptr and returns its result.
int mlx_argmax(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    const int rc = mlx_argmax_ptr(res, a, keepdims, s);
    return rc;
}
// Forwards to mlx_argmin_axis_ptr and returns its result.
int mlx_argmin_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    const int rc = mlx_argmin_axis_ptr(res, a, axis, keepdims, s);
    return rc;
}
// Forwards to mlx_argmin_ptr and returns its result.
int mlx_argmin(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    const int rc = mlx_argmin_ptr(res, a, keepdims, s);
    return rc;
}
// Forwards to mlx_argpartition_axis_ptr and returns its result.
int mlx_argpartition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) {
    const int rc = mlx_argpartition_axis_ptr(res, a, kth, axis, s);
    return rc;
}
// Forwards to mlx_argpartition_ptr and returns its result.
int mlx_argpartition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) {
    const int rc = mlx_argpartition_ptr(res, a, kth, s);
    return rc;
}
// Forwards to mlx_argsort_axis_ptr and returns its result.
int mlx_argsort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) {
    const int rc = mlx_argsort_axis_ptr(res, a, axis, s);
    return rc;
}
// Forwards to mlx_argsort_ptr and returns its result.
int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_argsort_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_array_equal_ptr and returns its result.
int mlx_array_equal(mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s) {
    const int rc = mlx_array_equal_ptr(res, a, b, equal_nan, s);
    return rc;
}
// Forwards to mlx_as_strided_ptr and returns its result.
int mlx_as_strided(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const int64_t* strides, size_t strides_num, size_t offset, const mlx_stream s) {
    const int rc = mlx_as_strided_ptr(res, a, shape, shape_num, strides, strides_num, offset, s);
    return rc;
}
// Forwards to mlx_astype_ptr and returns its result.
int mlx_astype(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) {
    const int rc = mlx_astype_ptr(res, a, dtype, s);
    return rc;
}
// Forwards to mlx_atleast_1d_ptr and returns its result.
int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_atleast_1d_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_atleast_2d_ptr and returns its result.
int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_atleast_2d_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_atleast_3d_ptr and returns its result.
int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_atleast_3d_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_bitwise_and_ptr and returns its result.
int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_bitwise_and_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_bitwise_invert_ptr and returns its result.
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_bitwise_invert_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_bitwise_or_ptr and returns its result.
int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_bitwise_or_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_bitwise_xor_ptr and returns its result.
int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_bitwise_xor_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_block_masked_mm_ptr and returns its result.
int mlx_block_masked_mm(mlx_array* res, const mlx_array a, const mlx_array b, int block_size, const mlx_array mask_out, const mlx_array mask_lhs, const mlx_array mask_rhs, const mlx_stream s) {
    const int rc = mlx_block_masked_mm_ptr(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
    return rc;
}
// Forwards to mlx_broadcast_arrays_ptr and returns its result.
int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s) {
    const int rc = mlx_broadcast_arrays_ptr(res, inputs, s);
    return rc;
}
// Forwards to mlx_broadcast_to_ptr and returns its result.
int mlx_broadcast_to(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) {
    const int rc = mlx_broadcast_to_ptr(res, a, shape, shape_num, s);
    return rc;
}
// Forwards to mlx_ceil_ptr and returns its result.
int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_ceil_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_clip_ptr and returns its result.
int mlx_clip(mlx_array* res, const mlx_array a, const mlx_array a_min, const mlx_array a_max, const mlx_stream s) {
    const int rc = mlx_clip_ptr(res, a, a_min, a_max, s);
    return rc;
}
// Forwards to mlx_concatenate_axis_ptr and returns its result.
int mlx_concatenate_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) {
    const int rc = mlx_concatenate_axis_ptr(res, arrays, axis, s);
    return rc;
}
// Forwards to mlx_concatenate_ptr and returns its result.
int mlx_concatenate(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) {
    const int rc = mlx_concatenate_ptr(res, arrays, s);
    return rc;
}
// Forwards to mlx_conjugate_ptr and returns its result.
int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_conjugate_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_contiguous_ptr and returns its result.
int mlx_contiguous(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s) {
    const int rc = mlx_contiguous_ptr(res, a, allow_col_major, s);
    return rc;
}
// Forwards to mlx_conv1d_ptr and returns its result.
int mlx_conv1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int groups, const mlx_stream s) {
    const int rc = mlx_conv1d_ptr(res, input, weight, stride, padding, dilation, groups, s);
    return rc;
}
// Forwards to mlx_conv2d_ptr and returns its result.
int mlx_conv2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, const mlx_stream s) {
    const int rc = mlx_conv2d_ptr(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s);
    return rc;
}
// Forwards to mlx_conv3d_ptr and returns its result.
int mlx_conv3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, const mlx_stream s) {
    const int rc = mlx_conv3d_ptr(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s);
    return rc;
}
// Forwards to mlx_conv_general_ptr and returns its result.
int mlx_conv_general(mlx_array* res, const mlx_array input, const mlx_array weight, const int* stride, size_t stride_num, const int* padding_lo, size_t padding_lo_num, const int* padding_hi, size_t padding_hi_num, const int* kernel_dilation, size_t kernel_dilation_num, const int* input_dilation, size_t input_dilation_num, int groups, bool flip, const mlx_stream s) {
    const int rc = mlx_conv_general_ptr(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s);
    return rc;
}
// Forwards to mlx_conv_transpose1d_ptr and returns its result.
int mlx_conv_transpose1d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s) {
    const int rc = mlx_conv_transpose1d_ptr(res, input, weight, stride, padding, dilation, output_padding, groups, s);
    return rc;
}
// Forwards to mlx_conv_transpose2d_ptr and returns its result.
int mlx_conv_transpose2d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, const mlx_stream s) {
    const int rc = mlx_conv_transpose2d_ptr(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s);
    return rc;
}
// Forwards to mlx_conv_transpose3d_ptr and returns its result.
int mlx_conv_transpose3d(mlx_array* res, const mlx_array input, const mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, const mlx_stream s) {
    const int rc = mlx_conv_transpose3d_ptr(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s);
    return rc;
}
// Forwards to mlx_copy_ptr and returns its result.
int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_copy_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_cos_ptr and returns its result.
int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_cos_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_cosh_ptr and returns its result.
int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_cosh_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_cummax_ptr and returns its result.
int mlx_cummax(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    const int rc = mlx_cummax_ptr(res, a, axis, reverse, inclusive, s);
    return rc;
}
// Forwards to mlx_cummin_ptr and returns its result.
int mlx_cummin(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    const int rc = mlx_cummin_ptr(res, a, axis, reverse, inclusive, s);
    return rc;
}
// Forwards to mlx_cumprod_ptr and returns its result.
int mlx_cumprod(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    const int rc = mlx_cumprod_ptr(res, a, axis, reverse, inclusive, s);
    return rc;
}
// Forwards to mlx_cumsum_ptr and returns its result.
int mlx_cumsum(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    const int rc = mlx_cumsum_ptr(res, a, axis, reverse, inclusive, s);
    return rc;
}
// Forwards to mlx_degrees_ptr and returns its result.
int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_degrees_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_depends_ptr and returns its result.
int mlx_depends(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array dependencies) {
    const int rc = mlx_depends_ptr(res, inputs, dependencies);
    return rc;
}
// Forwards to mlx_dequantize_ptr and returns its result.
int mlx_dequantize(mlx_array* res, const mlx_array w, const mlx_array scales, const mlx_array biases, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, mlx_optional_dtype dtype, const mlx_stream s) {
    const int rc = mlx_dequantize_ptr(res, w, scales, biases, group_size, bits, mode, dtype, s);
    return rc;
}
// Forwards to mlx_diag_ptr and returns its result.
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
    const int rc = mlx_diag_ptr(res, a, k, s);
    return rc;
}
// Forwards to mlx_diagonal_ptr and returns its result.
int mlx_diagonal(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, const mlx_stream s) {
    const int rc = mlx_diagonal_ptr(res, a, offset, axis1, axis2, s);
    return rc;
}
// Forwards to mlx_divide_ptr and returns its result.
int mlx_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_divide_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_divmod_ptr and returns its result.
int mlx_divmod(mlx_vector_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_divmod_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_einsum_ptr and returns its result.
int mlx_einsum(mlx_array* res, const char* subscripts, const mlx_vector_array operands, const mlx_stream s) {
    const int rc = mlx_einsum_ptr(res, subscripts, operands, s);
    return rc;
}
// Forwards to mlx_equal_ptr and returns its result.
int mlx_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_equal_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_erf_ptr and returns its result.
int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_erf_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_erfinv_ptr and returns its result.
int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_erfinv_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_exp_ptr and returns its result.
int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_exp_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_expand_dims_axes_ptr and returns its result.
int mlx_expand_dims_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) {
    const int rc = mlx_expand_dims_axes_ptr(res, a, axes, axes_num, s);
    return rc;
}
// Forwards to mlx_expand_dims_ptr and returns its result.
int mlx_expand_dims(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) {
    const int rc = mlx_expand_dims_ptr(res, a, axis, s);
    return rc;
}
// Forwards to mlx_expm1_ptr and returns its result.
int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_expm1_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_eye_ptr and returns its result.
int mlx_eye(mlx_array* res, int n, int m, int k, mlx_dtype dtype, const mlx_stream s) {
    const int rc = mlx_eye_ptr(res, n, m, k, dtype, s);
    return rc;
}
// Forwards to mlx_flatten_ptr and returns its result.
int mlx_flatten(mlx_array* res, const mlx_array a, int start_axis, int end_axis, const mlx_stream s) {
    const int rc = mlx_flatten_ptr(res, a, start_axis, end_axis, s);
    return rc;
}
// Forwards to mlx_floor_ptr and returns its result.
int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_floor_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_floor_divide_ptr and returns its result.
int mlx_floor_divide(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_floor_divide_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_from_fp8_ptr and returns its result.
int mlx_from_fp8(mlx_array* res, const mlx_array x, mlx_dtype dtype, const mlx_stream s) {
    const int rc = mlx_from_fp8_ptr(res, x, dtype, s);
    return rc;
}
// Forwards to mlx_full_ptr and returns its result.
int mlx_full(mlx_array* res, const int* shape, size_t shape_num, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) {
    const int rc = mlx_full_ptr(res, shape, shape_num, vals, dtype, s);
    return rc;
}
// Forwards to mlx_full_like_ptr and returns its result.
int mlx_full_like(mlx_array* res, const mlx_array a, const mlx_array vals, mlx_dtype dtype, const mlx_stream s) {
    const int rc = mlx_full_like_ptr(res, a, vals, dtype, s);
    return rc;
}
// Forwards to mlx_gather_ptr and returns its result.
int mlx_gather(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const int* axes, size_t axes_num, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) {
    const int rc = mlx_gather_ptr(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s);
    return rc;
}
// Forwards to mlx_gather_single_ptr and returns its result.
int mlx_gather_single(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const int* slice_sizes, size_t slice_sizes_num, const mlx_stream s) {
    const int rc = mlx_gather_single_ptr(res, a, indices, axis, slice_sizes, slice_sizes_num, s);
    return rc;
}
// Forwards to mlx_gather_mm_ptr and returns its result.
int mlx_gather_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array lhs_indices, const mlx_array rhs_indices, bool sorted_indices, const mlx_stream s) {
    const int rc = mlx_gather_mm_ptr(res, a, b, lhs_indices, rhs_indices, sorted_indices, s);
    return rc;
}
// Forwards to mlx_gather_qmm_ptr and returns its result.
int mlx_gather_qmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases, const mlx_array lhs_indices, const mlx_array rhs_indices, bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, bool sorted_indices, const mlx_stream s) {
    const int rc = mlx_gather_qmm_ptr(res, x, w, scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s);
    return rc;
}
// Forwards to mlx_greater_ptr and returns its result.
int mlx_greater(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_greater_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_greater_equal_ptr and returns its result.
int mlx_greater_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_greater_equal_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_hadamard_transform_ptr and returns its result.
int mlx_hadamard_transform(mlx_array* res, const mlx_array a, mlx_optional_float scale, const mlx_stream s) {
    const int rc = mlx_hadamard_transform_ptr(res, a, scale, s);
    return rc;
}
// Forwards to mlx_identity_ptr and returns its result.
int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
    const int rc = mlx_identity_ptr(res, n, dtype, s);
    return rc;
}
// Forwards to mlx_imag_ptr and returns its result.
int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_imag_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_inner_ptr and returns its result.
int mlx_inner(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_inner_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_isclose_ptr and returns its result.
int mlx_isclose(mlx_array* res, const mlx_array a, const mlx_array b, double rtol, double atol, bool equal_nan, const mlx_stream s) {
    const int rc = mlx_isclose_ptr(res, a, b, rtol, atol, equal_nan, s);
    return rc;
}
// Forwards to mlx_isfinite_ptr and returns its result.
int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_isfinite_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_isinf_ptr and returns its result.
int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_isinf_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_isnan_ptr and returns its result.
int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_isnan_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_isneginf_ptr and returns its result.
int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_isneginf_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_isposinf_ptr and returns its result.
int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_isposinf_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_kron_ptr and returns its result.
int mlx_kron(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_kron_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_left_shift_ptr and returns its result.
int mlx_left_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_left_shift_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_less_ptr and returns its result.
int mlx_less(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_less_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_less_equal_ptr and returns its result.
int mlx_less_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_less_equal_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_linspace_ptr and returns its result.
int mlx_linspace(mlx_array* res, double start, double stop, int num, mlx_dtype dtype, const mlx_stream s) {
    const int rc = mlx_linspace_ptr(res, start, stop, num, dtype, s);
    return rc;
}
// Forwards to mlx_log_ptr and returns its result.
int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_log_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_log10_ptr and returns its result.
int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_log10_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_log1p_ptr and returns its result.
int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_log1p_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_log2_ptr and returns its result.
int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_log2_ptr(res, a, s);
    return rc;
}
// Forwards to mlx_logaddexp_ptr and returns its result.
int mlx_logaddexp(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_logaddexp_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_logcumsumexp_ptr and returns its result.
int mlx_logcumsumexp(mlx_array* res, const mlx_array a, int axis, bool reverse, bool inclusive, const mlx_stream s) {
    const int rc = mlx_logcumsumexp_ptr(res, a, axis, reverse, inclusive, s);
    return rc;
}
// Forwards to mlx_logical_and_ptr and returns its result.
int mlx_logical_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    const int rc = mlx_logical_and_ptr(res, a, b, s);
    return rc;
}
// Forwards to mlx_logical_not_ptr and returns its result.
int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) {
    const int rc = mlx_logical_not_ptr(res, a, s);
    return rc;
}
// Thin wrapper: calls through mlx_logical_or_ptr and returns its result unchanged.
int mlx_logical_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_logical_or_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_logsumexp_axes_ptr and returns its result unchanged.
int mlx_logsumexp_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    return (*mlx_logsumexp_axes_ptr)(res, a, axes, axes_num, keepdims, s);
}
// Thin wrapper: calls through mlx_logsumexp_axis_ptr and returns its result unchanged.
int mlx_logsumexp_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    return (*mlx_logsumexp_axis_ptr)(res, a, axis, keepdims, s);
}
// Thin wrapper: calls through mlx_logsumexp_ptr and returns its result unchanged.
int mlx_logsumexp(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    return (*mlx_logsumexp_ptr)(res, a, keepdims, s);
}
// Thin wrapper: calls through mlx_masked_scatter_ptr and returns its result unchanged.
int mlx_masked_scatter(mlx_array* res, const mlx_array a, const mlx_array mask, const mlx_array src, const mlx_stream s) {
    return (*mlx_masked_scatter_ptr)(res, a, mask, src, s);
}
// Thin wrapper: calls through mlx_matmul_ptr and returns its result unchanged.
int mlx_matmul(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_matmul_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_max_axes_ptr and returns its result unchanged.
int mlx_max_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    return (*mlx_max_axes_ptr)(res, a, axes, axes_num, keepdims, s);
}
// Thin wrapper: calls through mlx_max_axis_ptr and returns its result unchanged.
int mlx_max_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    return (*mlx_max_axis_ptr)(res, a, axis, keepdims, s);
}
// Thin wrapper: calls through mlx_max_ptr and returns its result unchanged.
int mlx_max(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    return (*mlx_max_ptr)(res, a, keepdims, s);
}
// Thin wrapper: calls through mlx_maximum_ptr and returns its result unchanged.
int mlx_maximum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_maximum_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_mean_axes_ptr and returns its result unchanged.
int mlx_mean_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    return (*mlx_mean_axes_ptr)(res, a, axes, axes_num, keepdims, s);
}
// Thin wrapper: calls through mlx_mean_axis_ptr and returns its result unchanged.
int mlx_mean_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    return (*mlx_mean_axis_ptr)(res, a, axis, keepdims, s);
}
// Thin wrapper: calls through mlx_mean_ptr and returns its result unchanged.
int mlx_mean(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    return (*mlx_mean_ptr)(res, a, keepdims, s);
}
// Thin wrapper: calls through mlx_median_ptr and returns its result unchanged.
int mlx_median(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    return (*mlx_median_ptr)(res, a, axes, axes_num, keepdims, s);
}
// Thin wrapper: calls through mlx_meshgrid_ptr and returns its result unchanged.
int mlx_meshgrid(mlx_vector_array* res, const mlx_vector_array arrays, bool sparse, const char* indexing, const mlx_stream s) {
    return (*mlx_meshgrid_ptr)(res, arrays, sparse, indexing, s);
}
// Thin wrapper: calls through mlx_min_axes_ptr and returns its result unchanged.
int mlx_min_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    return (*mlx_min_axes_ptr)(res, a, axes, axes_num, keepdims, s);
}
// Thin wrapper: calls through mlx_min_axis_ptr and returns its result unchanged.
int mlx_min_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    return (*mlx_min_axis_ptr)(res, a, axis, keepdims, s);
}
// Thin wrapper: calls through mlx_min_ptr and returns its result unchanged.
int mlx_min(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    return (*mlx_min_ptr)(res, a, keepdims, s);
}
// Thin wrapper: calls through mlx_minimum_ptr and returns its result unchanged.
int mlx_minimum(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_minimum_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_moveaxis_ptr and returns its result unchanged.
int mlx_moveaxis(mlx_array* res, const mlx_array a, int source, int destination, const mlx_stream s) {
    return (*mlx_moveaxis_ptr)(res, a, source, destination, s);
}
// Thin wrapper: calls through mlx_multiply_ptr and returns its result unchanged.
int mlx_multiply(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_multiply_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_nan_to_num_ptr and returns its result unchanged.
int mlx_nan_to_num(mlx_array* res, const mlx_array a, float nan, mlx_optional_float posinf, mlx_optional_float neginf, const mlx_stream s) {
    return (*mlx_nan_to_num_ptr)(res, a, nan, posinf, neginf, s);
}
// Thin wrapper: calls through mlx_negative_ptr and returns its result unchanged.
int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_negative_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_not_equal_ptr and returns its result unchanged.
int mlx_not_equal(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_not_equal_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_number_of_elements_ptr and returns its result unchanged.
int mlx_number_of_elements(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool inverted, mlx_dtype dtype, const mlx_stream s) {
    return (*mlx_number_of_elements_ptr)(res, a, axes, axes_num, inverted, dtype, s);
}
// Thin wrapper: calls through mlx_ones_ptr and returns its result unchanged.
int mlx_ones(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) {
    return (*mlx_ones_ptr)(res, shape, shape_num, dtype, s);
}
// Thin wrapper: calls through mlx_ones_like_ptr and returns its result unchanged.
int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_ones_like_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_outer_ptr and returns its result unchanged.
int mlx_outer(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_outer_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_pad_ptr and returns its result unchanged.
int mlx_pad(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const int* low_pad_size, size_t low_pad_size_num, const int* high_pad_size, size_t high_pad_size_num, const mlx_array pad_value, const char* mode, const mlx_stream s) {
    return (*mlx_pad_ptr)(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, high_pad_size_num, pad_value, mode, s);
}
// Thin wrapper: calls through mlx_pad_symmetric_ptr and returns its result unchanged.
int mlx_pad_symmetric(mlx_array* res, const mlx_array a, int pad_width, const mlx_array pad_value, const char* mode, const mlx_stream s) {
    return (*mlx_pad_symmetric_ptr)(res, a, pad_width, pad_value, mode, s);
}
// Thin wrapper: calls through mlx_partition_axis_ptr and returns its result unchanged.
int mlx_partition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s) {
    return (*mlx_partition_axis_ptr)(res, a, kth, axis, s);
}
// Thin wrapper: calls through mlx_partition_ptr and returns its result unchanged.
int mlx_partition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s) {
    return (*mlx_partition_ptr)(res, a, kth, s);
}
// Thin wrapper: calls through mlx_power_ptr and returns its result unchanged.
int mlx_power(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_power_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_prod_axes_ptr and returns its result unchanged.
int mlx_prod_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    return (*mlx_prod_axes_ptr)(res, a, axes, axes_num, keepdims, s);
}
// Thin wrapper: calls through mlx_prod_axis_ptr and returns its result unchanged.
int mlx_prod_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    return (*mlx_prod_axis_ptr)(res, a, axis, keepdims, s);
}
// Thin wrapper: calls through mlx_prod_ptr and returns its result unchanged.
int mlx_prod(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    return (*mlx_prod_ptr)(res, a, keepdims, s);
}
// Thin wrapper: calls through mlx_put_along_axis_ptr and returns its result unchanged.
int mlx_put_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) {
    return (*mlx_put_along_axis_ptr)(res, a, indices, values, axis, s);
}
// Thin wrapper: calls through mlx_qqmm_ptr and returns its result unchanged.
int mlx_qqmm(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array w_scales, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
    return (*mlx_qqmm_ptr)(res, x, w, w_scales, group_size, bits, mode, s);
}
// Thin wrapper: calls through mlx_quantize_ptr and returns its result unchanged.
int mlx_quantize(mlx_vector_array* res, const mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
    return (*mlx_quantize_ptr)(res, w, group_size, bits, mode, s);
}
// Thin wrapper: calls through mlx_quantized_matmul_ptr and returns its result unchanged.
int mlx_quantized_matmul(mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, const mlx_array biases, bool transpose, mlx_optional_int group_size, mlx_optional_int bits, const char* mode, const mlx_stream s) {
    return (*mlx_quantized_matmul_ptr)(res, x, w, scales, biases, transpose, group_size, bits, mode, s);
}
// Thin wrapper: calls through mlx_radians_ptr and returns its result unchanged.
int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_radians_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_real_ptr and returns its result unchanged.
int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_real_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_reciprocal_ptr and returns its result unchanged.
int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_reciprocal_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_remainder_ptr and returns its result unchanged.
int mlx_remainder(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_remainder_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_repeat_axis_ptr and returns its result unchanged.
int mlx_repeat_axis(mlx_array* res, const mlx_array arr, int repeats, int axis, const mlx_stream s) {
    return (*mlx_repeat_axis_ptr)(res, arr, repeats, axis, s);
}
// Thin wrapper: calls through mlx_repeat_ptr and returns its result unchanged.
int mlx_repeat(mlx_array* res, const mlx_array arr, int repeats, const mlx_stream s) {
    return (*mlx_repeat_ptr)(res, arr, repeats, s);
}
// Thin wrapper: calls through mlx_reshape_ptr and returns its result unchanged.
int mlx_reshape(mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s) {
    return (*mlx_reshape_ptr)(res, a, shape, shape_num, s);
}
// Thin wrapper: calls through mlx_right_shift_ptr and returns its result unchanged.
int mlx_right_shift(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_right_shift_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_roll_axis_ptr and returns its result unchanged.
int mlx_roll_axis(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, int axis, const mlx_stream s) {
    return (*mlx_roll_axis_ptr)(res, a, shift, shift_num, axis, s);
}
// Thin wrapper: calls through mlx_roll_axes_ptr and returns its result unchanged.
int mlx_roll_axes(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_roll_axes_ptr)(res, a, shift, shift_num, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_roll_ptr and returns its result unchanged.
int mlx_roll(mlx_array* res, const mlx_array a, const int* shift, size_t shift_num, const mlx_stream s) {
    return (*mlx_roll_ptr)(res, a, shift, shift_num, s);
}
// Thin wrapper: calls through mlx_round_ptr and returns its result unchanged.
int mlx_round(mlx_array* res, const mlx_array a, int decimals, const mlx_stream s) {
    return (*mlx_round_ptr)(res, a, decimals, s);
}
// Thin wrapper: calls through mlx_rsqrt_ptr and returns its result unchanged.
int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_rsqrt_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_scatter_ptr and returns its result unchanged.
int mlx_scatter(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_scatter_ptr)(res, a, indices, updates, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_scatter_single_ptr and returns its result unchanged.
int mlx_scatter_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) {
    return (*mlx_scatter_single_ptr)(res, a, indices, updates, axis, s);
}
// Thin wrapper: calls through mlx_scatter_add_ptr and returns its result unchanged.
int mlx_scatter_add(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_scatter_add_ptr)(res, a, indices, updates, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_scatter_add_single_ptr and returns its result unchanged.
int mlx_scatter_add_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) {
    return (*mlx_scatter_add_single_ptr)(res, a, indices, updates, axis, s);
}
// Thin wrapper: calls through mlx_scatter_add_axis_ptr and returns its result unchanged.
int mlx_scatter_add_axis(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array values, int axis, const mlx_stream s) {
    return (*mlx_scatter_add_axis_ptr)(res, a, indices, values, axis, s);
}
// Thin wrapper: calls through mlx_scatter_max_ptr and returns its result unchanged.
int mlx_scatter_max(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_scatter_max_ptr)(res, a, indices, updates, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_scatter_max_single_ptr and returns its result unchanged.
int mlx_scatter_max_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) {
    return (*mlx_scatter_max_single_ptr)(res, a, indices, updates, axis, s);
}
// Thin wrapper: calls through mlx_scatter_min_ptr and returns its result unchanged.
int mlx_scatter_min(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_scatter_min_ptr)(res, a, indices, updates, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_scatter_min_single_ptr and returns its result unchanged.
int mlx_scatter_min_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) {
    return (*mlx_scatter_min_single_ptr)(res, a, indices, updates, axis, s);
}
// Thin wrapper: calls through mlx_scatter_prod_ptr and returns its result unchanged.
int mlx_scatter_prod(mlx_array* res, const mlx_array a, const mlx_vector_array indices, const mlx_array updates, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_scatter_prod_ptr)(res, a, indices, updates, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_scatter_prod_single_ptr and returns its result unchanged.
int mlx_scatter_prod_single(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_array updates, int axis, const mlx_stream s) {
    return (*mlx_scatter_prod_single_ptr)(res, a, indices, updates, axis, s);
}
// Thin wrapper: calls through mlx_segmented_mm_ptr and returns its result unchanged.
int mlx_segmented_mm(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_array segments, const mlx_stream s) {
    return (*mlx_segmented_mm_ptr)(res, a, b, segments, s);
}
// Thin wrapper: calls through mlx_sigmoid_ptr and returns its result unchanged.
int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_sigmoid_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_sign_ptr and returns its result unchanged.
int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_sign_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_sin_ptr and returns its result unchanged.
int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_sin_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_sinh_ptr and returns its result unchanged.
int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_sinh_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_slice_ptr and returns its result unchanged.
int mlx_slice(mlx_array* res, const mlx_array a, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) {
    return (*mlx_slice_ptr)(res, a, start, start_num, stop, stop_num, strides, strides_num, s);
}
// Thin wrapper: calls through mlx_slice_dynamic_ptr and returns its result unchanged.
int mlx_slice_dynamic(mlx_array* res, const mlx_array a, const mlx_array start, const int* axes, size_t axes_num, const int* slice_size, size_t slice_size_num, const mlx_stream s) {
    return (*mlx_slice_dynamic_ptr)(res, a, start, axes, axes_num, slice_size, slice_size_num, s);
}
// Thin wrapper: calls through mlx_slice_update_ptr and returns its result unchanged.
int mlx_slice_update(mlx_array* res, const mlx_array src, const mlx_array update, const int* start, size_t start_num, const int* stop, size_t stop_num, const int* strides, size_t strides_num, const mlx_stream s) {
    return (*mlx_slice_update_ptr)(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s);
}
// Thin wrapper: calls through mlx_slice_update_dynamic_ptr and returns its result unchanged.
int mlx_slice_update_dynamic(mlx_array* res, const mlx_array src, const mlx_array update, const mlx_array start, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_slice_update_dynamic_ptr)(res, src, update, start, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_softmax_axes_ptr and returns its result unchanged.
int mlx_softmax_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool precise, const mlx_stream s) {
    return (*mlx_softmax_axes_ptr)(res, a, axes, axes_num, precise, s);
}
// Thin wrapper: calls through mlx_softmax_axis_ptr and returns its result unchanged.
int mlx_softmax_axis(mlx_array* res, const mlx_array a, int axis, bool precise, const mlx_stream s) {
    return (*mlx_softmax_axis_ptr)(res, a, axis, precise, s);
}
// Thin wrapper: calls through mlx_softmax_ptr and returns its result unchanged.
int mlx_softmax(mlx_array* res, const mlx_array a, bool precise, const mlx_stream s) {
    return (*mlx_softmax_ptr)(res, a, precise, s);
}
// Thin wrapper: calls through mlx_sort_axis_ptr and returns its result unchanged.
int mlx_sort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) {
    return (*mlx_sort_axis_ptr)(res, a, axis, s);
}
// Thin wrapper: calls through mlx_sort_ptr and returns its result unchanged.
int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_sort_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_split_ptr and returns its result unchanged.
int mlx_split(mlx_vector_array* res, const mlx_array a, int num_splits, int axis, const mlx_stream s) {
    return (*mlx_split_ptr)(res, a, num_splits, axis, s);
}
// Thin wrapper: calls through mlx_split_sections_ptr and returns its result unchanged.
int mlx_split_sections(mlx_vector_array* res, const mlx_array a, const int* indices, size_t indices_num, int axis, const mlx_stream s) {
    return (*mlx_split_sections_ptr)(res, a, indices, indices_num, axis, s);
}
// Thin wrapper: calls through mlx_sqrt_ptr and returns its result unchanged.
int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_sqrt_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_square_ptr and returns its result unchanged.
int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_square_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_squeeze_axes_ptr and returns its result unchanged.
int mlx_squeeze_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_squeeze_axes_ptr)(res, a, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_squeeze_axis_ptr and returns its result unchanged.
int mlx_squeeze_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s) {
    return (*mlx_squeeze_axis_ptr)(res, a, axis, s);
}
// Thin wrapper: calls through mlx_squeeze_ptr and returns its result unchanged.
int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_squeeze_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_stack_axis_ptr and returns its result unchanged.
int mlx_stack_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s) {
    return (*mlx_stack_axis_ptr)(res, arrays, axis, s);
}
// Thin wrapper: calls through mlx_stack_ptr and returns its result unchanged.
int mlx_stack(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s) {
    return (*mlx_stack_ptr)(res, arrays, s);
}
// Thin wrapper: calls through mlx_std_axes_ptr and returns its result unchanged.
int mlx_std_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) {
    return (*mlx_std_axes_ptr)(res, a, axes, axes_num, keepdims, ddof, s);
}
// Thin wrapper: calls through mlx_std_axis_ptr and returns its result unchanged.
int mlx_std_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) {
    return (*mlx_std_axis_ptr)(res, a, axis, keepdims, ddof, s);
}
// Thin wrapper: calls through mlx_std_ptr and returns its result unchanged.
int mlx_std(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) {
    return (*mlx_std_ptr)(res, a, keepdims, ddof, s);
}
// Thin wrapper: calls through mlx_stop_gradient_ptr and returns its result unchanged.
int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_stop_gradient_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_subtract_ptr and returns its result unchanged.
int mlx_subtract(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s) {
    return (*mlx_subtract_ptr)(res, a, b, s);
}
// Thin wrapper: calls through mlx_sum_axes_ptr and returns its result unchanged.
int mlx_sum_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s) {
    return (*mlx_sum_axes_ptr)(res, a, axes, axes_num, keepdims, s);
}
// Thin wrapper: calls through mlx_sum_axis_ptr and returns its result unchanged.
int mlx_sum_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s) {
    return (*mlx_sum_axis_ptr)(res, a, axis, keepdims, s);
}
// Thin wrapper: calls through mlx_sum_ptr and returns its result unchanged.
int mlx_sum(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s) {
    return (*mlx_sum_ptr)(res, a, keepdims, s);
}
// Thin wrapper: calls through mlx_swapaxes_ptr and returns its result unchanged.
int mlx_swapaxes(mlx_array* res, const mlx_array a, int axis1, int axis2, const mlx_stream s) {
    return (*mlx_swapaxes_ptr)(res, a, axis1, axis2, s);
}
// Thin wrapper: calls through mlx_take_axis_ptr and returns its result unchanged.
int mlx_take_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) {
    return (*mlx_take_axis_ptr)(res, a, indices, axis, s);
}
// Thin wrapper: calls through mlx_take_ptr and returns its result unchanged.
int mlx_take(mlx_array* res, const mlx_array a, const mlx_array indices, const mlx_stream s) {
    return (*mlx_take_ptr)(res, a, indices, s);
}
// Thin wrapper: calls through mlx_take_along_axis_ptr and returns its result unchanged.
int mlx_take_along_axis(mlx_array* res, const mlx_array a, const mlx_array indices, int axis, const mlx_stream s) {
    return (*mlx_take_along_axis_ptr)(res, a, indices, axis, s);
}
// Thin wrapper: calls through mlx_tan_ptr and returns its result unchanged.
int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_tan_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_tanh_ptr and returns its result unchanged.
int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_tanh_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_tensordot_ptr and returns its result unchanged.
int mlx_tensordot(mlx_array* res, const mlx_array a, const mlx_array b, const int* axes_a, size_t axes_a_num, const int* axes_b, size_t axes_b_num, const mlx_stream s) {
    return (*mlx_tensordot_ptr)(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s);
}
// Thin wrapper: calls through mlx_tensordot_axis_ptr and returns its result unchanged.
int mlx_tensordot_axis(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s) {
    return (*mlx_tensordot_axis_ptr)(res, a, b, axis, s);
}
// Thin wrapper: calls through mlx_tile_ptr and returns its result unchanged.
int mlx_tile(mlx_array* res, const mlx_array arr, const int* reps, size_t reps_num, const mlx_stream s) {
    return (*mlx_tile_ptr)(res, arr, reps, reps_num, s);
}
// Thin wrapper: calls through mlx_to_fp8_ptr and returns its result unchanged.
int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s) {
    return (*mlx_to_fp8_ptr)(res, x, s);
}
// Thin wrapper: calls through mlx_topk_axis_ptr and returns its result unchanged.
int mlx_topk_axis(mlx_array* res, const mlx_array a, int k, int axis, const mlx_stream s) {
    return (*mlx_topk_axis_ptr)(res, a, k, axis, s);
}
// Thin wrapper: calls through mlx_topk_ptr and returns its result unchanged.
int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
    return (*mlx_topk_ptr)(res, a, k, s);
}
// Thin wrapper: calls through mlx_trace_ptr and returns its result unchanged.
int mlx_trace(mlx_array* res, const mlx_array a, int offset, int axis1, int axis2, mlx_dtype dtype, const mlx_stream s) {
    return (*mlx_trace_ptr)(res, a, offset, axis1, axis2, dtype, s);
}
// Thin wrapper: calls through mlx_transpose_axes_ptr and returns its result unchanged.
int mlx_transpose_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s) {
    return (*mlx_transpose_axes_ptr)(res, a, axes, axes_num, s);
}
// Thin wrapper: calls through mlx_transpose_ptr and returns its result unchanged.
int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_transpose_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_tri_ptr and returns its result unchanged.
int mlx_tri(mlx_array* res, int n, int m, int k, mlx_dtype type, const mlx_stream s) {
    return (*mlx_tri_ptr)(res, n, m, k, type, s);
}
// Thin wrapper: calls through mlx_tril_ptr and returns its result unchanged.
int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s) {
    return (*mlx_tril_ptr)(res, x, k, s);
}
// Thin wrapper: calls through mlx_triu_ptr and returns its result unchanged.
int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s) {
    return (*mlx_triu_ptr)(res, x, k, s);
}
// Thin wrapper: calls through mlx_unflatten_ptr and returns its result unchanged.
int mlx_unflatten(mlx_array* res, const mlx_array a, int axis, const int* shape, size_t shape_num, const mlx_stream s) {
    return (*mlx_unflatten_ptr)(res, a, axis, shape, shape_num, s);
}
// Thin wrapper: calls through mlx_var_axes_ptr and returns its result unchanged.
int mlx_var_axes(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, int ddof, const mlx_stream s) {
    return (*mlx_var_axes_ptr)(res, a, axes, axes_num, keepdims, ddof, s);
}
// Thin wrapper: calls through mlx_var_axis_ptr and returns its result unchanged.
int mlx_var_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, int ddof, const mlx_stream s) {
    return (*mlx_var_axis_ptr)(res, a, axis, keepdims, ddof, s);
}
// Thin wrapper: calls through mlx_var_ptr and returns its result unchanged.
int mlx_var(mlx_array* res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s) {
    return (*mlx_var_ptr)(res, a, keepdims, ddof, s);
}
// Thin wrapper: calls through mlx_view_ptr and returns its result unchanged.
int mlx_view(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s) {
    return (*mlx_view_ptr)(res, a, dtype, s);
}
// Thin wrapper: calls through mlx_where_ptr and returns its result unchanged.
int mlx_where(mlx_array* res, const mlx_array condition, const mlx_array x, const mlx_array y, const mlx_stream s) {
    return (*mlx_where_ptr)(res, condition, x, y, s);
}
// Thin wrapper: calls through mlx_zeros_ptr and returns its result unchanged.
int mlx_zeros(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_stream s) {
    return (*mlx_zeros_ptr)(res, shape, shape_num, dtype, s);
}
// Thin wrapper: calls through mlx_zeros_like_ptr and returns its result unchanged.
int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s) {
    return (*mlx_zeros_like_ptr)(res, a, s);
}
// Thin wrapper: calls through mlx_random_bernoulli_ptr and returns its result unchanged.
int mlx_random_bernoulli(mlx_array* res, const mlx_array p, const int* shape, size_t shape_num, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_bernoulli_ptr)(res, p, shape, shape_num, key, s);
}
// Thin wrapper: calls through mlx_random_bits_ptr and returns its result unchanged.
int mlx_random_bits(mlx_array* res, const int* shape, size_t shape_num, int width, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_bits_ptr)(res, shape, shape_num, width, key, s);
}
// Thin wrapper: calls through mlx_random_categorical_shape_ptr and returns its result unchanged.
int mlx_random_categorical_shape(mlx_array* res, const mlx_array logits, int axis, const int* shape, size_t shape_num, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_categorical_shape_ptr)(res, logits, axis, shape, shape_num, key, s);
}
// Thin wrapper: calls through mlx_random_categorical_num_samples_ptr and returns its result unchanged.
int mlx_random_categorical_num_samples(mlx_array* res, const mlx_array logits_, int axis, int num_samples, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_categorical_num_samples_ptr)(res, logits_, axis, num_samples, key, s);
}
// Thin wrapper: calls through mlx_random_categorical_ptr and returns its result unchanged.
int mlx_random_categorical(mlx_array* res, const mlx_array logits, int axis, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_categorical_ptr)(res, logits, axis, key, s);
}
// Thin wrapper: calls through mlx_random_gumbel_ptr and returns its result unchanged.
int mlx_random_gumbel(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_gumbel_ptr)(res, shape, shape_num, dtype, key, s);
}
// Thin wrapper: calls through mlx_random_key_ptr and returns its result unchanged.
int mlx_random_key(mlx_array* res, uint64_t seed) {
    return (*mlx_random_key_ptr)(res, seed);
}
// Thin wrapper: calls through mlx_random_laplace_ptr and returns its result unchanged.
int mlx_random_laplace(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_laplace_ptr)(res, shape, shape_num, dtype, loc, scale, key, s);
}
// Thin wrapper: calls through mlx_random_multivariate_normal_ptr and returns its result unchanged.
int mlx_random_multivariate_normal(mlx_array* res, const mlx_array mean, const mlx_array cov, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_multivariate_normal_ptr)(res, mean, cov, shape, shape_num, dtype, key, s);
}
// Thin wrapper: calls through mlx_random_normal_broadcast_ptr and returns its result unchanged.
int mlx_random_normal_broadcast(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array loc, const mlx_array scale, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_normal_broadcast_ptr)(res, shape, shape_num, dtype, loc, scale, key, s);
}
// Thin wrapper: calls through mlx_random_normal_ptr and returns its result unchanged.
int mlx_random_normal(mlx_array* res, const int* shape, size_t shape_num, mlx_dtype dtype, float loc, float scale, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_normal_ptr)(res, shape, shape_num, dtype, loc, scale, key, s);
}
// Thin wrapper: calls through mlx_random_permutation_ptr and returns its result unchanged.
int mlx_random_permutation(mlx_array* res, const mlx_array x, int axis, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_permutation_ptr)(res, x, axis, key, s);
}
// Thin wrapper: calls through mlx_random_permutation_arange_ptr and returns its result unchanged.
int mlx_random_permutation_arange(mlx_array* res, int x, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_permutation_arange_ptr)(res, x, key, s);
}
// Thin wrapper: calls through mlx_random_randint_ptr and returns its result unchanged.
int mlx_random_randint(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_randint_ptr)(res, low, high, shape, shape_num, dtype, key, s);
}
int mlx_random_seed(uint64_t seed) {
return mlx_random_seed_ptr(seed);
}
// Thin wrapper: calls through mlx_random_split_num_ptr and returns its result unchanged.
int mlx_random_split_num(mlx_array* res, const mlx_array key, int num, const mlx_stream s) {
    return (*mlx_random_split_num_ptr)(res, key, num, s);
}
// Thin wrapper: calls through mlx_random_split_ptr and returns its result unchanged.
int mlx_random_split(mlx_array* res_0, mlx_array* res_1, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_split_ptr)(res_0, res_1, key, s);
}
// Thin wrapper: calls through mlx_random_truncated_normal_ptr and returns its result unchanged.
int mlx_random_truncated_normal(mlx_array* res, const mlx_array lower, const mlx_array upper, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_truncated_normal_ptr)(res, lower, upper, shape, shape_num, dtype, key, s);
}
// Thin wrapper: calls through mlx_random_uniform_ptr and returns its result unchanged.
int mlx_random_uniform(mlx_array* res, const mlx_array low, const mlx_array high, const int* shape, size_t shape_num, mlx_dtype dtype, const mlx_array key, const mlx_stream s) {
    return (*mlx_random_uniform_ptr)(res, low, high, shape, shape_num, dtype, key, s);
}
mlx_stream mlx_stream_new(void) {
return mlx_stream_new_ptr();
}
// Thin wrapper: calls through mlx_stream_new_device_ptr and returns its result unchanged.
mlx_stream mlx_stream_new_device(mlx_device dev) {
    return (*mlx_stream_new_device_ptr)(dev);
}
// Thin wrapper: calls through mlx_stream_set_ptr and returns its result unchanged.
int mlx_stream_set(mlx_stream* stream, const mlx_stream src) {
    return (*mlx_stream_set_ptr)(stream, src);
}
// Thin wrapper: calls through mlx_stream_free_ptr and returns its result unchanged.
int mlx_stream_free(mlx_stream stream) {
    return (*mlx_stream_free_ptr)(stream);
}
// Thin wrapper: calls through mlx_stream_tostring_ptr and returns its result unchanged.
int mlx_stream_tostring(mlx_string* str, mlx_stream stream) {
    return (*mlx_stream_tostring_ptr)(str, stream);
}
// Thin wrapper: calls through mlx_stream_equal_ptr and returns its result unchanged.
bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) {
    return (*mlx_stream_equal_ptr)(lhs, rhs);
}
// Thin wrapper: calls through mlx_stream_get_device_ptr and returns its result unchanged.
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) {
    return (*mlx_stream_get_device_ptr)(dev, stream);
}
// Thin wrapper: calls through mlx_stream_get_index_ptr and returns its result unchanged.
int mlx_stream_get_index(int* index, mlx_stream stream) {
    return (*mlx_stream_get_index_ptr)(index, stream);
}
// Thin wrapper: calls through mlx_synchronize_ptr and returns its result unchanged.
int mlx_synchronize(mlx_stream stream) {
    return (*mlx_synchronize_ptr)(stream);
}
// Thin wrapper: calls through mlx_get_default_stream_ptr and returns its result unchanged.
int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) {
    return (*mlx_get_default_stream_ptr)(stream, dev);
}
// Thin wrapper: calls through mlx_set_default_stream_ptr and returns its result unchanged.
int mlx_set_default_stream(mlx_stream stream) {
    return (*mlx_set_default_stream_ptr)(stream);
}
mlx_stream mlx_default_cpu_stream_new(void) {
return mlx_default_cpu_stream_new_ptr();
}
mlx_stream mlx_default_gpu_stream_new(void) {
return mlx_default_gpu_stream_new_ptr();
}
mlx_string mlx_string_new(void) {
return mlx_string_new_ptr();
}
mlx_string mlx_string_new_data(const char* str) {
return mlx_string_new_data_ptr(str);
}
// Thin wrapper: calls through mlx_string_set_ptr and returns its result unchanged.
int mlx_string_set(mlx_string* str, const mlx_string src) {
    return (*mlx_string_set_ptr)(str, src);
}
// Thin wrapper: calls through mlx_string_data_ptr and returns its result unchanged.
const char* mlx_string_data(mlx_string str) {
    return (*mlx_string_data_ptr)(str);
}
// Thin wrapper: calls through mlx_string_free_ptr and returns its result unchanged.
int mlx_string_free(mlx_string str) {
    return (*mlx_string_free_ptr)(str);
}
int mlx_async_eval(const mlx_vector_array outputs) {
return mlx_async_eval_ptr(outputs);
}
// Thin wrapper: calls through mlx_checkpoint_ptr and returns its result unchanged.
int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) {
    return (*mlx_checkpoint_ptr)(res, fun);
}
// Thin wrapper: calls through mlx_custom_function_ptr and returns its result unchanged.
int mlx_custom_function(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp, const mlx_closure_custom_jvp fun_jvp, const mlx_closure_custom_vmap fun_vmap) {
    return (*mlx_custom_function_ptr)(res, fun, fun_vjp, fun_jvp, fun_vmap);
}
// Thin wrapper: calls through mlx_custom_vjp_ptr and returns its result unchanged.
int mlx_custom_vjp(mlx_closure* res, const mlx_closure fun, const mlx_closure_custom fun_vjp) {
    return (*mlx_custom_vjp_ptr)(res, fun, fun_vjp);
}
int mlx_eval(const mlx_vector_array outputs) {
return mlx_eval_ptr(outputs);
}
// Thin wrapper: calls through mlx_jvp_ptr and returns its result unchanged.
int mlx_jvp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array tangents) {
    return (*mlx_jvp_ptr)(res_0, res_1, fun, primals, tangents);
}
// Thin wrapper: calls through mlx_value_and_grad_ptr and returns its result unchanged.
int mlx_value_and_grad(mlx_closure_value_and_grad* res, const mlx_closure fun, const int* argnums, size_t argnums_num) {
    return (*mlx_value_and_grad_ptr)(res, fun, argnums, argnums_num);
}
// Thin wrapper: calls through mlx_vjp_ptr and returns its result unchanged.
int mlx_vjp(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array primals, const mlx_vector_array cotangents) {
    return (*mlx_vjp_ptr)(res_0, res_1, fun, primals, cotangents);
}
// Thin wrapper: calls through mlx_detail_vmap_replace_ptr and returns its result unchanged.
int mlx_detail_vmap_replace(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_vector_array s_inputs, const mlx_vector_array s_outputs, const int* in_axes, size_t in_axes_num, const int* out_axes, size_t out_axes_num) {
    return (*mlx_detail_vmap_replace_ptr)(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num);
}
// Thin wrapper: calls through mlx_detail_vmap_trace_ptr and returns its result unchanged.
int mlx_detail_vmap_trace(mlx_vector_array* res_0, mlx_vector_array* res_1, const mlx_closure fun, const mlx_vector_array inputs, const int* in_axes, size_t in_axes_num) {
    return (*mlx_detail_vmap_trace_ptr)(res_0, res_1, fun, inputs, in_axes, in_axes_num);
}
mlx_vector_array mlx_vector_array_new(void) {
return mlx_vector_array_new_ptr();
}
// Thin wrapper: calls through mlx_vector_array_set_ptr and returns its result unchanged.
int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) {
    return (*mlx_vector_array_set_ptr)(vec, src);
}
// Thin wrapper: calls through mlx_vector_array_free_ptr and returns its result unchanged.
int mlx_vector_array_free(mlx_vector_array vec) {
    return (*mlx_vector_array_free_ptr)(vec);
}
mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size) {
return mlx_vector_array_new_data_ptr(data, size);
}
mlx_vector_array mlx_vector_array_new_value(const mlx_array val) {
return mlx_vector_array_new_value_ptr(val);
}
// Thin wrapper: calls through mlx_vector_array_set_data_ptr and returns its result unchanged.
int mlx_vector_array_set_data(mlx_vector_array* vec, const mlx_array* data, size_t size) {
    return (*mlx_vector_array_set_data_ptr)(vec, data, size);
}
// Thin wrapper: calls through mlx_vector_array_set_value_ptr and returns its result unchanged.
int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) {
    return (*mlx_vector_array_set_value_ptr)(vec, val);
}
// Thin wrapper: calls through mlx_vector_array_append_data_ptr and returns its result unchanged.
int mlx_vector_array_append_data(mlx_vector_array vec, const mlx_array* data, size_t size) {
    return (*mlx_vector_array_append_data_ptr)(vec, data, size);
}
// Thin wrapper: calls through mlx_vector_array_append_value_ptr and returns its result unchanged.
int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) {
    return (*mlx_vector_array_append_value_ptr)(vec, val);
}
// Thin wrapper: calls through mlx_vector_array_size_ptr and returns its result unchanged.
size_t mlx_vector_array_size(mlx_vector_array vec) {
    return (*mlx_vector_array_size_ptr)(vec);
}
// Thin wrapper: calls through mlx_vector_array_get_ptr and returns its result unchanged.
int mlx_vector_array_get(mlx_array* res, const mlx_vector_array vec, size_t idx) {
    return (*mlx_vector_array_get_ptr)(res, vec, idx);
}
mlx_vector_vector_array mlx_vector_vector_array_new(void) {
return mlx_vector_vector_array_new_ptr();
}
// Thin wrapper: calls through mlx_vector_vector_array_set_ptr and returns its result unchanged.
int mlx_vector_vector_array_set(mlx_vector_vector_array* vec, const mlx_vector_vector_array src) {
    return (*mlx_vector_vector_array_set_ptr)(vec, src);
}
// Thin wrapper: calls through mlx_vector_vector_array_free_ptr and returns its result unchanged.
int mlx_vector_vector_array_free(mlx_vector_vector_array vec) {
    return (*mlx_vector_vector_array_free_ptr)(vec);
}
mlx_vector_vector_array mlx_vector_vector_array_new_data(const mlx_vector_array* data, size_t size) {
return mlx_vector_vector_array_new_data_ptr(data, size);
}
mlx_vector_vector_array mlx_vector_vector_array_new_value(const mlx_vector_array val) {
return mlx_vector_vector_array_new_value_ptr(val);
}
// Thin wrapper: calls through mlx_vector_vector_array_set_data_ptr and returns its result unchanged.
int mlx_vector_vector_array_set_data(mlx_vector_vector_array* vec, const mlx_vector_array* data, size_t size) {
    return (*mlx_vector_vector_array_set_data_ptr)(vec, data, size);
}
// Thin wrapper: calls through mlx_vector_vector_array_set_value_ptr and returns its result unchanged.
int mlx_vector_vector_array_set_value(mlx_vector_vector_array* vec, const mlx_vector_array val) {
    return (*mlx_vector_vector_array_set_value_ptr)(vec, val);
}
// Thin wrapper: calls through mlx_vector_vector_array_append_data_ptr and returns its result unchanged.
int mlx_vector_vector_array_append_data(mlx_vector_vector_array vec, const mlx_vector_array* data, size_t size) {
    return (*mlx_vector_vector_array_append_data_ptr)(vec, data, size);
}
// Thin wrapper: calls through mlx_vector_vector_array_append_value_ptr and returns its result unchanged.
int mlx_vector_vector_array_append_value(mlx_vector_vector_array vec, const mlx_vector_array val) {
    return (*mlx_vector_vector_array_append_value_ptr)(vec, val);
}
// Thin wrapper: calls through mlx_vector_vector_array_size_ptr and returns its result unchanged.
size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) {
    return (*mlx_vector_vector_array_size_ptr)(vec);
}
// Thin wrapper: calls through mlx_vector_vector_array_get_ptr and returns its result unchanged.
int mlx_vector_vector_array_get(mlx_vector_array* res, const mlx_vector_vector_array vec, size_t idx) {
    return (*mlx_vector_vector_array_get_ptr)(res, vec, idx);
}
mlx_vector_int mlx_vector_int_new(void) {
return mlx_vector_int_new_ptr();
}
// Thin wrapper: calls through mlx_vector_int_set_ptr and returns its result unchanged.
int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) {
    return (*mlx_vector_int_set_ptr)(vec, src);
}
// Thin wrapper: calls through mlx_vector_int_free_ptr and returns its result unchanged.
int mlx_vector_int_free(mlx_vector_int vec) {
    return (*mlx_vector_int_free_ptr)(vec);
}
mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) {
return mlx_vector_int_new_data_ptr(data, size);
}
mlx_vector_int mlx_vector_int_new_value(int val) {
return mlx_vector_int_new_value_ptr(val);
}
// Thin wrapper: calls through mlx_vector_int_set_data_ptr and returns its result unchanged.
int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) {
    return (*mlx_vector_int_set_data_ptr)(vec, data, size);
}
// Thin wrapper: calls through mlx_vector_int_set_value_ptr and returns its result unchanged.
int mlx_vector_int_set_value(mlx_vector_int* vec, int val) {
    return (*mlx_vector_int_set_value_ptr)(vec, val);
}
// Thin wrapper: calls through mlx_vector_int_append_data_ptr and returns its result unchanged.
int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) {
    return (*mlx_vector_int_append_data_ptr)(vec, data, size);
}
// Thin wrapper: calls through mlx_vector_int_append_value_ptr and returns its result unchanged.
int mlx_vector_int_append_value(mlx_vector_int vec, int val) {
    return (*mlx_vector_int_append_value_ptr)(vec, val);
}
// Thin wrapper: calls through mlx_vector_int_size_ptr and returns its result unchanged.
size_t mlx_vector_int_size(mlx_vector_int vec) {
    return (*mlx_vector_int_size_ptr)(vec);
}
int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) {
return mlx_vector_int_get_ptr(res, vec, idx);
}
mlx_vector_string mlx_vector_string_new(void) {
return mlx_vector_string_new_ptr();
}
// Thin wrapper: calls through mlx_vector_string_set_ptr and returns its result unchanged.
int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) {
    return (*mlx_vector_string_set_ptr)(vec, src);
}
// Thin wrapper: calls through mlx_vector_string_free_ptr and returns its result unchanged.
int mlx_vector_string_free(mlx_vector_string vec) {
    return (*mlx_vector_string_free_ptr)(vec);
}
mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) {
return mlx_vector_string_new_data_ptr(data, size);
}
mlx_vector_string mlx_vector_string_new_value(const char* val) {
return mlx_vector_string_new_value_ptr(val);
}
// Thin wrapper: calls through mlx_vector_string_set_data_ptr and returns its result unchanged.
int mlx_vector_string_set_data(mlx_vector_string* vec, const char** data, size_t size) {
    return (*mlx_vector_string_set_data_ptr)(vec, data, size);
}
// Thin wrapper: calls through mlx_vector_string_set_value_ptr and returns its result unchanged.
int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) {
    return (*mlx_vector_string_set_value_ptr)(vec, val);
}
// Thin wrapper: calls through mlx_vector_string_append_data_ptr and returns its result unchanged.
int mlx_vector_string_append_data(mlx_vector_string vec, const char** data, size_t size) {
    return (*mlx_vector_string_append_data_ptr)(vec, data, size);
}
// Thin wrapper: calls through mlx_vector_string_append_value_ptr and returns its result unchanged.
int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) {
    return (*mlx_vector_string_append_value_ptr)(vec, val);
}
// Thin wrapper: calls through mlx_vector_string_size_ptr and returns its result unchanged.
size_t mlx_vector_string_size(mlx_vector_string vec) {
    return (*mlx_vector_string_size_ptr)(vec);
}
int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) {
return mlx_vector_string_get_ptr(res, vec, idx);
}
// Thin wrapper: calls through mlx_version_ptr and returns its result unchanged.
int mlx_version(mlx_string* str_) {
    return (*mlx_version_ptr)(str_);
}