Scala ❤️ Apple MLX?

I got this (slightly odd) script working on an M2 Mac.

https://github.com/Quafadas/vecxt/blob/gonative/experiments/src/mlx.scala

Which is more-or-less a translation of
https://github.com/ml-explore/mlx-c/blob/main/examples/example.c

MLX is Apple’s NumPy/Torch-like framework for Apple silicon.

It went through Java’s Project Panama (all a bit “build from source” - both jextract and mlx-c) but, to my astonishment, it works. Here’s the output:

=== Testing MetalDeviceInfo wrapper ===
Info: (applegpu_g14s,17179869184,22906503168,34359738368)

=== String creation with data test ===
Created string data: Hello, MLX!

=== Array creation test ===
Created MLX stream: Stream(Device(cpu, 0), 1)
Created MLX stream: Stream(Device(gpu, 0), 0)
Created MLX arrays:
array([[1, 2, 3],
       [4, 5, 6]], dtype=float32)
array([[5, 2, 7],
       [45, 5.5, 6]], dtype=float32)
Added on CPU:
array([[6, 4, 10],
       [49, 10.5, 12]], dtype=float32)
Added on GPU:
array([[6, 4, 10],
       [49, 10.5, 12]], dtype=float32)

Which is kind of cool, as it proves you can get at your Apple silicon’s GPU compute from Scala on the JVM.

The hypothesis is that Scala 3 could be a cool place to explore MLX:

  • Arenas make memory allocation convenient
  • opaque type MlxArray = MemorySegment makes it possible to retain a nice, type-safe API with low runtime cost.
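A minimal sketch of those two points, runnable without MLX installed (it assumes Scala 3 and JDK 22 or newer, where java.lang.foreign is final). FloatBuffer, NativeFloats and roundTrip are hypothetical names made up for illustration, standing in for something like MlxArray; none of this is taken from vecxt.

```scala
import java.lang.foreign.{Arena, MemorySegment, ValueLayout}

// Hypothetical stand-in for MlxArray: at runtime this is just a MemorySegment,
// but outside this object the compiler treats it as a distinct type.
object NativeFloats:
  opaque type FloatBuffer = MemorySegment

  // Copy the values into off-heap memory owned by the given arena.
  def fromArray(values: Array[Float])(using arena: Arena): FloatBuffer =
    arena.allocateFrom(ValueLayout.JAVA_FLOAT, values*)

  extension (buf: FloatBuffer)
    def length: Int = (buf.byteSize() / ValueLayout.JAVA_FLOAT.byteSize()).toInt
    def apply(i: Int): Float = buf.getAtIndex(ValueLayout.JAVA_FLOAT, i.toLong)
    def toArray: Array[Float] = buf.toArray(ValueLayout.JAVA_FLOAT)

import NativeFloats.*

// Allocate inside a confined arena, copy back to the heap, then free the
// native memory by closing the arena.
def roundTrip(values: Array[Float]): Array[Float] =
  val arena = Arena.ofConfined()
  try
    given Arena = arena
    val buf = fromArray(values)
    buf.toArray
  finally arena.close()

@main def opaqueDemo(): Unit =
  println(roundTrip(Array(1f, 2f, 3f)).mkString(", "))
```

Outside NativeFloats, a FloatBuffer cannot be passed where a raw MemorySegment is expected (or the other way round), yet no wrapper object is allocated, which is the low-cost type safety the second bullet is after.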

I guess the point of posting here would be to see if there is someone else interested in contributing to an exploration of this idea - I’d carve what I have out of its existing chaos.

It’s a niche in a niche (Apple silicon) in a niche (people interested in ML) in a niche (who want to explore something other than Python)... but you never know, maybe there is someone!

As I’m never likely to have a beefy Nvidia card to futz around with, I figured I’d investigate making the most of what I already have…



Hi there,
this is interesting! May I know how you implemented the MLX part?
Songpeng

I don’t know if I’ve understood the question 100%, but I can sketch the process…

MLX has C++, C, Swift and Python (<- for obvious reasons) APIs. To the best of my knowledge, we can’t plug directly into any of these from the JVM.

However, the C ABI is interesting, because Oracle’s Project Panama seems to be designed with exactly this sort of interop in mind.
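To show what “plugging into a C ABI” looks like before any MLX is involved, here is a minimal sketch that calls the C standard library’s strlen from Scala. It assumes JDK 22 or newer and a 64-bit platform (so that size_t fits a Java long); nativeStrlen is a made-up name. jextract essentially generates this kind of plumbing for every function in a header.

```scala
import java.lang.foreign.{Arena, FunctionDescriptor, Linker, ValueLayout}

// Call the C standard library's strlen through Panama's foreign function API.
def nativeStrlen(s: String): Long =
  val linker = Linker.nativeLinker()
  // Look up the symbol in the platform's default C library.
  val address = linker.defaultLookup().find("strlen").orElseThrow()
  // C signature: size_t strlen(const char*), modelled here as (address) -> long.
  val descriptor = FunctionDescriptor.of(ValueLayout.JAVA_LONG, ValueLayout.ADDRESS)
  val strlen = linker.downcallHandle(address, descriptor)
  val arena = Arena.ofConfined()
  try
    // Copy the string into native memory as a NUL-terminated UTF-8 C string.
    val cString = arena.allocateFrom(s)
    // invokeWithArguments is slower than invokeExact, but it is an ordinary
    // varargs method, so it needs no signature-polymorphic call support.
    strlen.invokeWithArguments(cString).asInstanceOf[Long]
  finally arena.close()

@main def strlenDemo(): Unit =
  println(nativeStrlen("Hello, MLX!"))
```

Running with --enable-native-access=ALL-UNNAMED avoids the JVM’s warning about restricted native access.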

Step 1 - MLX-C

Install mlx-c locally, using the instructions in the repo.

Verify: Run the examples from the doc site, in C/C++ or whatever that repo uses.

Step 2 - Jextract

As far as I can tell, Oracle doesn’t (yet?) publish jextract. Clone jextract and follow the instructions to build a jextract executable. You might need to read its docs.

Verify: That the jextract executable you built works for one of the sample cases in the jextract repo.

Step 3 - Generate Bindings to MLX-C

For me, that looked like;

 /Users/simon/Code/jextract-1/build/jextract/bin/jextract \
    -t mlx \
    --use-system-load-library \
    -l mlxc \
    --output generated/src \
    --include-dir /Users/simon/Code/mlx-c/ \
    /Users/simon/Code/mlx-c/mlx/c/mlx.h

Replace machine-dependent paths (<- don’t hate me) as necessary.

Verify: That process completed successfully

Step 4 - Polish Bindings

OOTB, those bindings compiled fine in Java but not in Scala. My recollection is that I had to delete the wait function binding and search-and-replace instances of $ in the bindings.

jextract will generate a lot of files for you.

https://github.com/Quafadas/vecxt/tree/gonative/generated/src/mlx

Verify: That these compile with Java, and then with Scala.

Step 5 - Run some trivial example

You’ll need to set the right flags - in particular, tell it where to find the mlx-c build and allow the JVM to access native memory:
https://github.com/Quafadas/vecxt/blob/16b109b160be4050565c9b72f25a0ab8575525bc/experiments/package.mill#L37
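A note on what those flags usually amount to; this is an assumption, not something read out of the linked build file. Because the bindings were generated with --use-system-load-library and -l mlxc, the library is loaded via System.loadLibrary, which searches java.library.path. So the JVM typically needs -Djava.library.path pointing at the directory holding the mlx-c shared library, plus --enable-native-access=ALL-UNNAMED. The sketch below (findLibrary and checkMlxc are made-up names) checks whether the library file can be found on a given search path, which helps when debugging an UnsatisfiedLinkError.

```scala
import java.io.File
import java.nio.file.{Files, Path, Paths}

// Return the first location on the search path that holds the platform-specific
// file for `libName` (System.mapLibraryName gives libmlxc.dylib on macOS).
def findLibrary(libName: String, searchPath: String): Option[Path] =
  val fileName = System.mapLibraryName(libName)
  searchPath
    .split(File.pathSeparator)
    .iterator
    .filter(_.nonEmpty)
    .map(dir => Paths.get(dir).resolve(fileName))
    .find(p => Files.isRegularFile(p))

@main def checkMlxc(): Unit =
  val searchPath = System.getProperty("java.library.path", "")
  findLibrary("mlxc", searchPath) match
    case Some(path) => println(s"Found $path")
    case None =>
      println(s"${System.mapLibraryName("mlxc")} not found on java.library.path: $searchPath")
```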

I didn’t find it very easy to get this first thing working. The generated bindings are somewhat arcane, and I settled on this as a trivial starting point:

https://github.com/Quafadas/vecxt/blob/gonative/experiments/src/MetalDeviceInfo.scala

https://github.com/Quafadas/vecxt/blob/16b109b160be4050565c9b72f25a0ab8575525bc/experiments/src/mlx.scala#L74

Future

This is where I stopped. My observations so far:

  • jextract-generated bindings are a PITA to work with. I made an AI write nice Scala APIs for them.
  • opaque types and memory segments seem to be a beautiful match
  • I really want to try sn-bindgen on this, but… time
  • Getting this to anything approaching “usable” is an absolutely terrific amount of work that is way beyond my hobby time allocation…

It’s a cool POC though. It should all work (reasonably easily, just do step 2, the rest should be done for you) on the gonative branch of vecxt.

Sorry for the massive post. I hope it helps - did it answer the Q? Also happy to chat on Discord.


Thank you so much for your response. I thought you used Scala Native, but it turns out that there are tools to generate the Java (or Scala) API bindings.