jax
JAX on CPU
Using the CPU version of JAX is pretty easy: a plain pip install of jax and jaxlib is enough. Remember to use the pip provided by your conda env to avoid package clashes and other nasty stuff.
JAX has many sharp bits you need to pay attention to, such as how to enable 64-bit floats and how to use random number generators. Please refer to the JAX documentation for these details.
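As a rough illustration of the two sharp bits named above, here is a minimal sketch. It assumes a reasonably recent JAX where jax.config.update is available; older releases spell the same setting through the config object in jax.config. The seed and array shapes are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp

# 64-bit floats are off by default; switch them on before creating any arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64 once x64 is enabled, float32 otherwise

# Random numbers are driven by explicit keys instead of hidden global state.
key = jax.random.PRNGKey(0)

# Split the key before each draw; reusing a key gives the same numbers again.
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (2,))

key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (2,))
print(a, b)
```

Setting the environment variable JAX_ENABLE_X64=1 before starting Python is an alternative way to enable 64-bit mode.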
JAX on GPU
Enabling JAX on GPU is another story: it is somewhat involved and error-prone. Below I describe exactly how to get the GPU version of JAX working on our HPC.
TL;DR
Details
Nvidia is notorious for its many versions of GPU drivers, cuda, and cudnn, and for how they conflict or are compatible with each other. See this page for some compatible trios. The list is definitely not the whole story, though, since many driver versions are omitted and some trios not on the list evidently also work. So it is a headache anyway.
In our HPC setup, the GPU driver on master is 418 and on c9 is 430. In my tests with jaxlib, cuda 10.1 + cudnn 7.6 is a workable combination for the GPUs in both machines. Note that cuda 10.0 + cudnn 7.5 fails at import jax with the error
ImportError: .../jaxlib/xla_extension.so: symbol cudnnSetCTCLossDescriptorEx version libcudnn.so.7 not defined in file libcudnn.so.7 with link time reference
See this issue. Such errors usually indicate that the installed cuda and cudnn versions do not match what jaxlib (and the XLA inside it) was built against, or do not work with the current GPU driver. It is always hard to figure out exactly which combinations of driver, cuda, cudnn, and jaxlib work. I only give the solution for our HPC: with GPU driver versions 418 and 430, jaxlib 0.1.47 + cuda 10.1.243 + cudnn 7.6.5.32 works. Other combinations? Sorry, I don't know. Try them at your own risk.
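One quick way to tell whether a given combination actually works is to import jax and list the devices it sees. The sketch below is only a sanity check, not part of the original setup. It assumes jax.devices() is available in your version and that GPU devices report their platform as "gpu" (or "cuda" in some releases). If the GPU backend fails to initialize, JAX usually falls back to CPU with a warning, which this check makes visible.

```python
import jax
from jaxlib import version as jaxlib_version

# Report the installed versions so they can be compared with the known-good ones.
print("jax", jax.__version__, "jaxlib", jaxlib_version.__version__)

# List every device JAX can use in this process.
devices = jax.devices()
print("devices:", devices)

# Assumption: GPU devices report platform "gpu" (or "cuda" in some releases).
on_gpu = any(d.platform in ("gpu", "cuda") for d in devices)
if on_gpu:
    print("JAX sees at least one GPU")
else:
    print("JAX is running on", devices[0].platform, "only; check driver, cuda, cudnn, and jaxlib")
```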
Even after you have configured the right combination of cuda and cudnn (provided by spack) and jaxlib (installed following the instructions here), you may hit another error when actually running some code.
The error is RuntimeError: Internal: libdevice not found at ./libdevice.10.bc; see details in this issue. The cause is that XLA is not very smart here: it cannot find a cuda installation outside the default path unless it is told where to look. So one needs to run
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda
before actually running JAX. In our case, the exact command is
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/ubuntu/spack/opt/spack/linux-ubuntu18.04-x86_64/gcc-7.4.0/cuda-10.1.243-ohxd3xdnjd3ayvjdi2ku7dtam643l7vd
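If exporting the variable in every shell is inconvenient, the same flag can be set from Python, as long as it happens before jax is imported. The following is a minimal sketch under some assumptions: the helper names with_cuda_data_dir and has_libdevice are made up for this example, /path/to/cuda is a placeholder to be replaced with the real cuda prefix, and libdevice.10.bc is assumed to live under nvvm/libdevice inside that prefix, as it does in cuda 10.x toolkits.

```python
import os


def with_cuda_data_dir(cuda_dir, existing=""):
    """Return an XLA_FLAGS string that points XLA at cuda_dir.

    Flags already present in `existing` are kept, except an earlier
    --xla_gpu_cuda_data_dir entry, which is replaced.
    """
    prefix = "--xla_gpu_cuda_data_dir="
    kept = [flag for flag in existing.split() if not flag.startswith(prefix)]
    return " ".join(kept + [prefix + cuda_dir])


def has_libdevice(cuda_dir):
    """Check for libdevice.10.bc in the assumed nvvm/libdevice location."""
    path = os.path.join(cuda_dir, "nvvm", "libdevice", "libdevice.10.bc")
    return os.path.isfile(path)


# Placeholder path: replace it with the real cuda prefix on your machine.
cuda_dir = "/path/to/cuda"

# This must run before `import jax`, since XLA reads the flags at startup.
os.environ["XLA_FLAGS"] = with_cuda_data_dir(cuda_dir, os.environ.get("XLA_FLAGS", ""))

if not has_libdevice(cuda_dir):
    print("warning: libdevice.10.bc not found under", cuda_dir)

# import jax  # only after XLA_FLAGS is set
```

Keeping any flags that are already set means the snippet does not clobber other XLA options you may have exported elsewhere.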