Apply JAX jit to the environments and methods
JAX transformations and compilation are designed to work only on Python functions that are functionally pure (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Therefore, it is recommended to refactor the code into a functional style (see: https://github.com/google/jax/issues/1567).
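A minimal sketch of why purity matters (the names `scale`, `impure_force` and `pure_force` are illustrative, not taken from this project): `jax.jit` traces a function once per input shape/dtype and caches the compiled result, so a Python global read inside the function is frozen at trace time, while a value passed as an argument is not.

```python
import jax
import jax.numpy as jnp

scale = 2.0  # module-level global (the impure dependency)


@jax.jit
def impure_force(x):
    # Reads the global `scale`; its value is baked in when the function is first traced
    return -scale * x


@jax.jit
def pure_force(x, scale):
    # Every input the result depends on is an explicit argument
    return -scale * x


x = jnp.array([1.0, 2.0])
before = impure_force(x)     # traced and compiled with scale = 2.0 -> [-2., -4.]
scale = 3.0
after = impure_force(x)      # cache hit: still [-2., -4.], the new global is ignored
pure = pure_force(x, scale)  # [-3., -6.], the argument is respected
```

The same pitfall applies to environment objects that mutate internal state between calls: under jit, such state has to be passed in and returned explicitly.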
Before we do that, let us test whether just-in-time compilation gives a significant speedup for the SDE and butan environments.
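A minimal timing sketch for this comparison, assuming a hypothetical 1D double-well force as the workload (the actual SDE and butan environments are not reproduced here). It times eager against jitted calls, excludes the first call so compilation cost is not counted, and calls `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def double_well_energy(x):
    # Hypothetical workload: V(x) = (x^2 - 1)^2 summed over all walkers
    return jnp.sum((x**2 - 1.0) ** 2)


def force(x):
    # F(x) = -dV/dx, the quantity an SDE step evaluates at every iteration
    return -jax.grad(double_well_energy)(x)


force_jit = jax.jit(force)


def benchmark(fn, x, n_repeats=100):
    """Mean wall-clock seconds per call, excluding the first (compiling) call."""
    fn(x).block_until_ready()  # warm-up: triggers compilation for jitted functions
    start = time.perf_counter()
    for _ in range(n_repeats):
        fn(x).block_until_ready()  # wait for asynchronous dispatch to finish
    return (time.perf_counter() - start) / n_repeats


x = jnp.linspace(-2.0, 2.0, 10_000)
t_eager = benchmark(force, x)
t_jit = benchmark(force_jit, x)
print(f"backend={jax.default_backend()}  eager={t_eager:.2e} s  jit={t_jit:.2e} s  "
      f"speedup={t_eager / t_jit:.1f}x")
```

Running the same script once per backend gives a first estimate of whether refactoring for jit is worth the effort.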
- double well
- butan
- Euler-Maruyama
- test ct improvement on a CPU
- test ct improvement on a GPU
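A minimal sketch of a jit-compiled Euler-Maruyama integrator written in the functional style, assuming overdamped Langevin dynamics dX = -V'(X) dt + sigma dW in a hypothetical 1D double well V(x) = (x^2 - 1)^2. The function names and parameters are illustrative, not the project's API. The PRNG key is passed explicitly and the time loop uses `jax.lax.scan`, so the whole trajectory compiles into a single XLA program.

```python
from functools import partial

import jax
import jax.numpy as jnp


def double_well_potential(x):
    # Hypothetical 1D double well with minima at x = -1 and x = +1
    return (x**2 - 1.0) ** 2


def drift(x):
    # b(x) = -dV/dx, evaluated elementwise for a batch of independent walkers
    return -jax.vmap(jax.grad(double_well_potential))(x)


@partial(jax.jit, static_argnames="n_steps")
def euler_maruyama(key, x0, dt, sigma, n_steps):
    """Integrate dX = b(X) dt + sigma dW; returns shape (n_steps, *x0.shape)."""
    keys = jax.random.split(key, n_steps)  # one independent key per step

    def step(x, k):
        dw = jnp.sqrt(dt) * jax.random.normal(k, x.shape)  # Brownian increment ~ N(0, dt)
        x_next = x + drift(x) * dt + sigma * dw
        return x_next, x_next  # (new carry, value stored in the trajectory)

    _, traj = jax.lax.scan(step, x0, keys)
    return traj


key = jax.random.PRNGKey(0)
x0 = -jnp.ones(512)  # all walkers start in the left well
traj = euler_maruyama(key, x0, 1e-3, 0.7, n_steps=5_000)
print(traj.shape)  # (5000, 512)
```

`n_steps` is marked static because the number of split keys must be known at trace time. JAX runs on its default backend (`jax.default_backend()` reports which one); setting the environment variable `JAX_PLATFORMS=cpu` before starting Python forces the CPU, so the same script can be timed on both CPU and GPU.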