// Direct Access to MLX Library
// Source: GitHub gist Fraer/b380c9b53a49f1c79ada24624a4fba30 by @Fraer (last active May 19, 2025 15:06).
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import com.sun.jna.Structure;
import com.sun.jna.ptr.PointerByReference;
import com.sun.jna.Memory;
import com.sun.jna.NativeLong;
import java.util.Arrays;
/**
 * Example of using JNA to call into the MLX shared library on Apple Silicon
 * directly, without an intermediate C/C++ wrapper.
 *
 * <p>This is a demonstration of the technique and of its limits, not a supported
 * way to use MLX: the bindings below depend on compiler-specific mangled names and
 * on the in-memory layout of C++ standard-library types.
 */
public class DirectMLXExample {

    /**
     * Direct MLX library interface definition.
     *
     * <p>Each method is named after the mangled C++ symbol it is meant to bind to, so
     * the names must match the symbols exported by the installed library exactly.
     *
     * <p>NOTE(review): the names below encode {@code std::vector<int>} as
     * {@code St6vectorIiSaIiEE}. Standard libraries that place their types in an inline
     * namespace mangle this differently, so these symbols may not resolve against a
     * given {@code libmlx.dylib}; confirm with {@code nm -gU libmlx.dylib}.
     *
     * <p>NOTE(review): JNA invokes these as plain C functions. C++ constructors,
     * member functions and functions returning class types by value presumably pass
     * hidden arguments (the {@code this} pointer, a return-value slot) that this
     * mapping does not model; verify against the platform C++ ABI before relying on
     * any returned {@link Pointer}.
     */
    public interface MLXLibrary extends Library {
        /** Shared binding; initialising this field loads {@code libmlx.dylib}. */
        MLXLibrary INSTANCE = Native.load("libmlx.dylib", MLXLibrary.class);

        /**
         * Mangled name for one constructor signature of {@code mlx::core::array}.
         * The exact mangled name depends on the MLX version and compiler.
         *
         * @param data         {@code const float*} element data
         * @param shape_vector {@code const std::vector<int>&} shape
         * @param dtype        {@code mlx::core::Dtype} passed as an int
         * @return whatever the native call leaves in the return register, as a pointer
         */
        Pointer _ZN3mlx4core5arrayC1EPKfRKSt6vectorIiSaIiEENS0_5DtypeE(
            Pointer data,
            Pointer shape_vector,
            int dtype
        );

        /**
         * Mangled name for {@code mlx::core::zeros}.
         *
         * @param shape_vector {@code const std::vector<int>&} shape
         * @param dtype        {@code mlx::core::Dtype} passed as an int
         * @return pointer treated by this example as the created array
         */
        Pointer _ZN3mlx4core5zerosERKSt6vectorIiSaIiEENS0_5DtypeE(
            Pointer shape_vector,
            int dtype
        );

        /**
         * Mangled name for the addition operator ({@code operator+}) on
         * {@code mlx::core::array}.
         *
         * @param self  {@code const mlx::core::array*} left operand
         * @param other {@code const mlx::core::array&} right operand
         * @return pointer treated by this example as the resulting array
         */
        Pointer _ZNK3mlx4core5arrayplERKS1_(
            Pointer self,
            Pointer other
        );

        /**
         * Mangled name for {@code mlx::core::eval}.
         *
         * @param array {@code mlx::core::array&} to evaluate
         */
        void _ZN3mlx4core4evalERNS0_5arrayE(
            Pointer array
        );

        /**
         * Intended to return the array data as host memory.
         * NOTE(review): confirm this symbol exists in the target MLX build.
         *
         * @param self {@code const mlx::core::array*}
         * @return pointer to the host data
         */
        Pointer _ZNK3mlx4core5array7tohostIFaEEv(
            Pointer self
        );

        /**
         * Returns the array size.
         * NOTE(review): the C++ return type is not visible here; if it is wider than
         * 32 bits, mapping it to {@code int} truncates the value. TODO confirm.
         *
         * @param self {@code const mlx::core::array*}
         * @return the size as reported by the native call
         */
        int _ZNK3mlx4core5array4sizeEv(
            Pointer self
        );
    }

    /**
     * Simple {@code std::vector<int>} representation for JNA: three consecutive
     * pointers (first element, one past the last element, end of allocated storage).
     *
     * <p>NOTE(review): this assumes the native standard library lays out
     * {@code std::vector} as exactly these three pointers in this order; verify for
     * the library MLX was built against.
     */
    public static class StdVector extends Structure {
        public Pointer _M_start; // Pointer to first element
        public Pointer _M_finish; // Pointer to one past last element
        public Pointer _M_end_of_storage; // Pointer to end of allocated storage

        /** Creates a vector header backed by memory that JNA allocates itself. */
        public StdVector() {
            super();
        }

        /** Maps a vector header onto the start of caller-provided memory. */
        private StdVector(Pointer backing) {
            super(backing);
        }

        @Override
        protected java.util.List<String> getFieldOrder() {
            return Arrays.asList("_M_start", "_M_finish", "_M_end_of_storage");
        }

        /**
         * Creates a native {@code std::vector<int>} image holding a copy of {@code values}.
         *
         * <p>The header and the elements live in a single {@link Memory} allocation and
         * that allocation is what is returned. The native memory therefore stays valid
         * for as long as the caller keeps the returned pointer reachable; with separate
         * allocations the element buffer could be reclaimed as soon as this method
         * returned, leaving the header pointing at freed memory.
         *
         * <p>The native side must never resize or free the vector, because the storage
         * is owned by JNA rather than by the C++ allocator.
         *
         * @param values elements to copy; may be empty
         * @return pointer to the vector header (the three-pointer structure)
         * @throws NullPointerException if {@code values} is null
         */
        public static Pointer createIntVector(int[] values) {
            // Computed in long so that very large arrays cannot overflow int arithmetic.
            final long headerBytes = 3L * Native.POINTER_SIZE;
            final long dataBytes = (long) values.length * Integer.BYTES;

            // One allocation: [ header | elements ]. The header makes the size non-zero
            // even for an empty array, which Memory would otherwise reject.
            Memory block = new Memory(headerBytes + dataBytes);
            block.write(headerBytes, values, 0, values.length);

            StdVector header = new StdVector(block);
            header._M_start = block.share(headerBytes);
            header._M_finish = block.share(headerBytes + dataBytes);
            // Capacity equals size: there is no spare storage after the last element.
            header._M_end_of_storage = header._M_finish;
            header.write();

            // Return the owning Memory itself rather than a detached address.
            return block;
        }
    }

    /**
     * Main method demonstrating direct MLX usage through JNA: builds a 2x2 float
     * array and a 2x2 zero array, adds them and evaluates the result.
     *
     * <p>Java-level failures, including a missing library or an unresolved symbol,
     * are reported on standard error. A crash inside native code cannot be caught
     * here and terminates the JVM.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            // Looked up inside the try block so that a library that fails to load is
            // reported by the handler below instead of escaping main.
            MLXLibrary lib = MLXLibrary.INSTANCE;

            // Shape vector [2, 2]
            Pointer shapeVector = StdVector.createIntVector(new int[]{2, 2});

            // Data for a 2x2 matrix, copied into native memory in one bulk write.
            float[] data = {1.0f, 2.0f, 3.0f, 4.0f};
            Memory dataMemory = new Memory((long) data.length * Float.BYTES);
            dataMemory.write(0, data, 0, data.length);

            // Constant for MLX Dtype (0 = float32).
            // NOTE(review): confirm this value against the MLX Dtype definition.
            final int MLX_FLOAT32 = 0;

            // Create MLX array directly
            Pointer array1 = lib._ZN3mlx4core5arrayC1EPKfRKSt6vectorIiSaIiEENS0_5DtypeE(
                dataMemory, shapeVector, MLX_FLOAT32);

            // Create another array with zeros
            Pointer array2 = lib._ZN3mlx4core5zerosERKSt6vectorIiSaIiEENS0_5DtypeE(
                shapeVector, MLX_FLOAT32);

            // Add the arrays
            Pointer sumArray = lib._ZNK3mlx4core5arrayplERKS1_(array1, array2);

            // Evaluate the result
            lib._ZN3mlx4core4evalERNS0_5arrayE(sumArray);

            System.out.println("Operation completed successfully");
            // Note: Memory management is extremely difficult with this approach.
            // The native arrays created above are never destroyed; proper cleanup
            // would need the C++ destructors, which is complex from JNA.
        } catch (Exception | LinkageError e) {
            // LinkageError covers UnsatisfiedLinkError (library or symbol not found)
            // and class-initialisation failures, none of which are Exceptions.
            System.err.println("Error accessing MLX library directly: " + e.getMessage());
            e.printStackTrace();
        }
    }
}
// (End of gist; GitHub sign-up / sign-in comment prompt omitted.)