JAX
JAX on CPU
Installing the CPU version of JAX is pretty easy: a plain pip install is enough. Remember to use the pip provided in your conda env to avoid package clashes and other nasty stuff.
JAX has many sharp bits you need to pay attention to, such as how to enable 64-bit floats and how to use random number generators. Please refer to the JAX documentation for these details.
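As a quick taste of the two sharp bits mentioned above, here is a minimal sketch. It assumes a reasonably recent JAX where `jax.config.update` is available (older releases expose the same switch via `from jax.config import config`); the variable names are only illustrative.

```python
import jax

# JAX defaults to 32-bit floats. Flip this switch at startup,
# before any arrays are created, to get 64-bit precision.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random

x = jnp.ones(3)
print(x.dtype)  # float64 once x64 is enabled

# Random numbers are driven by explicit keys instead of hidden global state.
key = random.PRNGKey(0)
key, subkey = random.split(key)
a = random.normal(subkey, (2,))
b = random.normal(subkey, (2,))   # reusing a key reproduces the same numbers

key, subkey2 = random.split(key)  # split again to get fresh randomness
c = random.normal(subkey2, (2,))

print(a, b, c)
```

The habit to build is: never reuse a key when you want independent samples, always split first.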
JAX on GPU
Enabling JAX on GPU is another story: it is somewhat involved and error prone. Below I elaborate the exact way to get the GPU version of JAX working on our HPC.
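Whichever route you take, it helps to check which backend JAX actually picked up, since a misconfigured GPU install may quietly fall back to CPU. The sketch below only uses `jax.default_backend()` and `jax.devices()`; the exact platform string reported for a GPU may differ between JAX versions, so the set of names checked here is an assumption.

```python
import jax

backend = jax.default_backend()  # e.g. "cpu" or "gpu"
devices = jax.devices()          # list of devices visible to JAX

print("backend:", backend)
print("devices:", devices)

# Assumption: GPU backends report one of these platform names.
if backend in ("gpu", "cuda", "rocm"):
    print("JAX is running on GPU.")
else:
    print("JAX is NOT using a GPU; check the installation.")
```

Run this on a compute node that actually has a GPU attached; on a login node without one, CPU is the expected answer.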
TL;DR
Details