Maybe I just got unlucky, but there seems to be a CUDA compatibility issue with recent versions of JAX, specifically >=0.6.0. See https://github.com/jax-ml/jax/issues/29042 for details. It may be good to mention an upper limit on the JAX version so that things work smoothly! :)
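
In case it helps, here is a rough sketch of what such a limit could look like as a startup check. The bound of 0.6.0 is only my assumption from the versions where I hit the problem, and `parse_release` and `is_below_upper_bound` are hypothetical helper names, not part of JAX or this project. In a requirements file the same idea would be the specifier `jax<0.6.0`.

```python
import re

# Assumed upper bound, taken only from the versions where the issue showed up.
JAX_UPPER_BOUND = (0, 6, 0)


def parse_release(version: str) -> tuple:
    """Return the leading numeric part of a version string as a tuple of ints."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    # Pad to three components so "0.6" compares equal to "0.6.0".
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def is_below_upper_bound(version: str, bound: tuple = JAX_UPPER_BOUND) -> bool:
    """True if the numeric release is strictly below the bound.

    Pre-release suffixes are ignored, so "0.6.0rc1" is treated like "0.6.0"
    and rejected, which errs on the side of caution.
    """
    return parse_release(version) < bound
```

With JAX installed, the check could be called as `is_below_upper_bound(jax.__version__)` and used to print a warning before any GPU work starts.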