quimb.experimental.autojittn

Decorator for automatically just-in-time compiling tensor network functions.

TODO:

- [ ] go via an intermediate pytree / array function that could be shared,
      e.g. with the TNOptimizer class.

Classes

AutojittedTN

Class to hold the autojit_tn decorated function callable.

Functions

autojit_tn([fn, decorator, check_inputs])

Decorate a tensor network function to be just-in-time compiled / traced.

Module Contents

class quimb.experimental.autojittn.AutojittedTN(fn, decorator=ar.autojit, check_inputs=True, **decorator_opts)

Class to hold the autojit_tn decorated function callable.

setup_fn(tn, *args, **kwargs)
__call__(tn, *args, **kwargs)
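The split between setup_fn and __call__, together with the check_inputs option, suggests a cache of compiled array functions keyed on the structure of the input network. The following is a toy, pure-Python sketch of that idea, not quimb's implementation: ToyTN, counting_decorator and ToyAutojittedTN are hypothetical names, the "network" is just named flat lists, and the "compiler" returns the function unchanged while counting how often it is invoked. quimb's real cache key and data handling may differ.

```python
from typing import Callable, Dict, Hashable, List, Optional, Tuple


class ToyTN:
    """Hypothetical stand-in for a TensorNetwork: named tensors stored as flat lists."""

    def __init__(self, tensors: Dict[str, List[float]]):
        self.tensors = dict(tensors)

    def structure(self) -> Tuple[Tuple[str, int], ...]:
        # Everything except the numeric values: tensor names and sizes.
        return tuple((name, len(self.tensors[name])) for name in sorted(self.tensors))

    def arrays(self) -> List[List[float]]:
        # The raw numeric data, in a fixed (sorted-name) order.
        return [self.tensors[name] for name in sorted(self.tensors)]


def counting_decorator(fn: Callable) -> Callable:
    """Hypothetical 'compiler': returns fn unchanged but counts how often it is invoked."""
    counting_decorator.num_compiles += 1
    return fn


counting_decorator.num_compiles = 0


class ToyAutojittedTN:
    """Hypothetical sketch: hold one compiled array function per input structure."""

    def __init__(self, fn: Callable, decorator: Callable = counting_decorator,
                 check_inputs: bool = True):
        self.fn = fn
        self.decorator = decorator
        self.check_inputs = check_inputs
        self.fn_store: Dict[Optional[Hashable], Callable] = {}

    def setup_fn(self, tn: ToyTN) -> Callable:
        names = sorted(tn.tensors)

        def array_fn(arrays: List[List[float]]):
            # A pure function of the arrays only: rebuild a network with the
            # captured structure and the new numeric data, then call fn on it.
            return self.fn(ToyTN(dict(zip(names, arrays))))

        return self.decorator(array_fn)

    def __call__(self, tn: ToyTN):
        # With check_inputs=False every call shares one cached function.
        key = tn.structure() if self.check_inputs else None
        if key not in self.fn_store:
            self.fn_store[key] = self.setup_fn(tn)
        return self.fn_store[key](tn.arrays())


# Usage: a scalar-valued "network function".
def total(tn: ToyTN) -> float:
    return sum(sum(data) for data in tn.tensors.values())


jitted = ToyAutojittedTN(total)
r1 = jitted(ToyTN({"A": [1.0, 2.0], "B": [3.0]}))         # compiles
r2 = jitted(ToyTN({"A": [0.5, 0.5], "B": [1.0]}))         # same structure: reused
r3 = jitted(ToyTN({"A": [1.0], "B": [1.0], "C": [1.0]}))  # new structure: recompiles
```

In this toy, passing check_inputs=False and then calling with a differently structured network silently reuses the stale function, which mirrors the warning in the check_inputs parameter description.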
quimb.experimental.autojittn.autojit_tn(fn=None, decorator=ar.autojit, check_inputs=True, **decorator_opts)

Decorate a tensor network function to be just-in-time compiled / traced. Only the array operations are traced, resulting in a completely static computational graph with no side effects. The resulting function can be much faster if called repeatedly with only numeric changes, or hardware accelerated if a library such as jax is used.
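To make "traces solely array operations" concrete, here is a toy tracer over plain scalars rather than arrays. Traced and trace are hypothetical names; autoray.autojit and jax.jit trace real array operations and are far more capable. The function runs once on placeholder values, every arithmetic operation is recorded on a tape, and later calls just replay the tape with new numbers. A Python side effect inside the function therefore happens only once, during tracing, since it is not part of the recorded graph.

```python
import operator
from typing import Callable, List, Sequence


class Traced:
    """Symbolic placeholder: refers to one slot on a shared tape of operations."""

    def __init__(self, tape: list, idx: int):
        self.tape = tape
        self.idx = idx

    def _record(self, op: Callable, other) -> "Traced":
        # Another Traced value is referenced by slot; a plain number is baked in as a constant.
        rhs = other.idx if isinstance(other, Traced) else ("const", other)
        self.tape.append((op, self.idx, rhs))
        return Traced(self.tape, len(self.tape) - 1)

    def __add__(self, other):
        return self._record(operator.add, other)

    def __mul__(self, other):
        return self._record(operator.mul, other)

    # Addition and multiplication commute, so the reflected versions can reuse them.
    __radd__ = __add__
    __rmul__ = __mul__


def trace(fn: Callable, num_inputs: int) -> Callable[[Sequence[float]], float]:
    """Run fn once on placeholders, then return a replay function over the recorded tape."""
    tape: list = [("input", i) for i in range(num_inputs)]
    out_idx = fn(*[Traced(tape, i) for i in range(num_inputs)]).idx

    def replay(inputs: Sequence[float]) -> float:
        values: List[float] = []
        for entry in tape:
            if entry[0] == "input":
                values.append(inputs[entry[1]])
            else:
                op, lhs, rhs = entry
                rhs_value = rhs[1] if isinstance(rhs, tuple) else values[rhs]
                values.append(op(values[lhs], rhs_value))
        return values[out_idx]

    return replay


side_effects: List[str] = []


def fn(a, b):
    side_effects.append("called")  # Python side effect: not recorded on the tape
    return a * b + 2 * a


compiled = trace(fn, 2)
r1 = compiled([3.0, 4.0])  # 3*4 + 2*3
r2 = compiled([1.0, 5.0])  # 1*5 + 2*1
```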

Parameters:
  • fn (callable) – The function to be decorated. It should take a TensorNetwork as its first argument and either act in place on it or return a raw scalar or array.

  • decorator (callable) – The decorator to use to wrap the underlying array function, for example jax.jit. Defaults to autoray.autojit.

  • check_inputs (bool, optional) – Whether to check the inputs on every call to see if a new compiled function needs to be generated. If False, the same compiled function is reused for all inputs, which may give incorrect results if the inputs differ from those it was compiled for. Defaults to True.

  • decorator_opts – Options to pass to the decorator, e.g. backend for autoray.autojit.
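The signature autojit_tn(fn=None, ..., **decorator_opts) suggests the common pattern of a decorator usable both bare and with keyword arguments, with extra options forwarded to the underlying compiler. The sketch below illustrates only that calling pattern under stated assumptions: toy_autojit_tn and identity_jit are hypothetical, forwarding options as decorator(fn, **opts) is an assumption, and the toy stores check_inputs without performing any input checking or caching.

```python
import functools
from typing import Any, Callable, Optional


def identity_jit(fn: Callable, **opts: Any) -> Callable:
    """Hypothetical stand-in for a compiler such as jax.jit: records its options."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapped.compiler_opts = opts
    return wrapped


def toy_autojit_tn(fn: Optional[Callable] = None, decorator: Callable = identity_jit,
                   check_inputs: bool = True, **decorator_opts: Any):
    """Hypothetical sketch of a decorator usable both bare and with arguments."""
    if fn is None:
        # Called as @toy_autojit_tn(...): return a decorator that still expects fn.
        return functools.partial(toy_autojit_tn, decorator=decorator,
                                 check_inputs=check_inputs, **decorator_opts)
    # Called as @toy_autojit_tn or toy_autojit_tn(fn): wrap immediately,
    # forwarding any extra keyword options to the underlying decorator.
    wrapped = decorator(fn, **decorator_opts)
    wrapped.check_inputs = check_inputs  # stored only; this toy does no input checking
    return wrapped


@toy_autojit_tn
def double(x):
    return 2 * x


@toy_autojit_tn(check_inputs=False, backend="numpy")
def increment(x):
    return x + 1
```

Here backend="numpy" is not consumed by toy_autojit_tn itself and ends up in decorator_opts, which is how an option such as backend for autoray.autojit would reach the compiler in this pattern.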