quimb.experimental.autojittn

Decorator for automatically just-in-time compiling tensor network functions.

TODO:

  • [ ] check and cache on input shapes

Classes

AutojittedTN(fn[, decorator])

Functions

try_and_get_params(x)

autojit_tn([fn, decorator])

Decorate a tensor network function to be just-in-time compiled (traced).

Module Contents

quimb.experimental.autojittn.try_and_get_params(x)

class quimb.experimental.autojittn.AutojittedTN(fn, decorator=ar.autojit, **decorator_opts)
    fn
    jit_fn = None
    decorator
    decorator_opts
    _setup(*args, **kwargs)
    __call__(*args, backend=None, **kwargs)
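
The attribute jit_fn = None suggests that AutojittedTN compiles lazily: nothing is built until the first call, at which point _setup creates the wrapped function. The following is a minimal sketch of that pattern under this assumption. LazyJitted and counting_decorator are hypothetical names, the backend keyword of __call__ is omitted, and unlike the real _setup (which presumably inspects its arguments, for example the input tensor network) this version ignores them.

```python
class LazyJitted:
    """Hypothetical sketch mirroring the attributes listed for AutojittedTN."""

    def __init__(self, fn, decorator, **decorator_opts):
        self.fn = fn
        self.jit_fn = None  # nothing is compiled until the first call
        self.decorator = decorator
        self.decorator_opts = decorator_opts

    def _setup(self, *args, **kwargs):
        # Build the wrapped function once; the call arguments are ignored here.
        self.jit_fn = self.decorator(self.fn, **self.decorator_opts)

    def __call__(self, *args, **kwargs):
        if self.jit_fn is None:
            self._setup(*args, **kwargs)
        return self.jit_fn(*args, **kwargs)


compilations = []


def counting_decorator(fn):
    """Stand-in for a jit decorator that records each time it is applied."""
    compilations.append(fn)
    return fn


square = LazyJitted(lambda x: x * x, counting_decorator)
print(len(compilations))     # 0: not compiled yet
print(square(3), square(4))  # 9 16
print(len(compilations))     # 1: compiled once, on the first call
```

Deferring the work to the first call means the decorator can see real inputs before building anything, which is useful when the compiled function depends on the structure of those inputs.
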
quimb.experimental.autojittn.autojit_tn(fn=None, decorator=ar.autojit, **decorator_opts)

Decorate a tensor network function to be just-in-time compiled (traced). Only the array operations are traced, resulting in a fully static computational graph with no side effects. The resulting function can be much faster when called repeatedly with inputs that differ only in their numeric values, and it can be hardware accelerated if a library such as jax is used.
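
To illustrate what tracing into a static graph means, here is a toy tracer in plain Python. It is not how quimb or autoray implement this: Traced, evaluate and trace are hypothetical names, and real tracers such as jax.jit also compile the recorded graph. The function runs once on placeholder values that record each operation; later calls only replay the recorded arithmetic, so Python-level side effects inside the function happen once.

```python
import operator


class Traced:
    """Hypothetical placeholder that records operations instead of computing them."""

    def __init__(self, op, args):
        self.op = op      # None marks an input placeholder
        self.args = args  # operands, or the input position for placeholders

    def __add__(self, other):
        return Traced(operator.add, (self, other))

    def __radd__(self, other):
        return Traced(operator.add, (other, self))

    def __mul__(self, other):
        return Traced(operator.mul, (self, other))

    def __rmul__(self, other):
        return Traced(operator.mul, (other, self))


def evaluate(node, values):
    """Replay the recorded graph on concrete input values."""
    if not isinstance(node, Traced):
        return node  # a constant captured during tracing
    if node.op is None:
        return values[node.args]
    return node.op(*(evaluate(arg, values) for arg in node.args))


def trace(fn, n_inputs):
    """Run fn once on placeholders and return a function that replays the graph."""
    placeholders = [Traced(None, i) for i in range(n_inputs)]
    graph = fn(*placeholders)  # the Python code in fn runs only here

    def replay(*values):
        return evaluate(graph, values)

    return replay


log = []


def f(x, y):
    log.append("traced")  # a Python side effect, not an array operation
    return 3 * x + x * y


f_static = trace(f, 2)
print(f_static(1.0, 2.0))  # 5.0
print(f_static(2.0, 0.5))  # 7.0
print(log)                 # ['traced']: the side effect happened only once
```

Because only the recorded operations are replayed, Python control flow that depends on input values cannot be captured this way; in this toy version, for instance, comparing a placeholder with a number raises a TypeError.
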

Parameters:
  • fn (callable) – The function to be decorated.

  • decorator (callable) – The decorator used to wrap the underlying array function, for example jax.jit. Defaults to autoray.autojit.

  • decorator_opts – Options to pass to the decorator, e.g. backend for autoray.autojit.
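
The fn=None default suggests that autojit_tn can be applied either bare (@autojit_tn) or with options (@autojit_tn(decorator=..., ...)). A common plain-Python way to support both forms is sketched below. autojit_sketch and recording_decorator are hypothetical names, the stand-in decorator performs no compilation, and passing fn positionally to the decorator is an assumption, not necessarily how quimb calls it. The real function presumably returns an AutojittedTN instance rather than wrapping immediately.

```python
import functools


def recording_decorator(fn, **opts):
    """Hypothetical stand-in for a jit decorator such as jax.jit.

    It compiles nothing: it wraps fn unchanged and stores the options it
    received so they can be inspected.
    """

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapped.opts = opts
    return wrapped


def autojit_sketch(fn=None, decorator=recording_decorator, **decorator_opts):
    """Hypothetical decorator usable bare or with options."""
    if fn is None:
        # Used as @autojit_sketch(...): return a decorator still awaiting fn.
        return functools.partial(
            autojit_sketch, decorator=decorator, **decorator_opts
        )
    # Used as @autojit_sketch: wrap fn, forwarding the collected options.
    return decorator(fn, **decorator_opts)


@autojit_sketch
def double(x):
    return 2 * x


@autojit_sketch(backend="numpy")
def increment(x):
    return x + 1


print(double(3), double.opts)        # 6 {}
print(increment(3), increment.opts)  # 4 {'backend': 'numpy'}
```

Any keyword that is not fn or decorator ends up in decorator_opts and is forwarded unchanged, which matches the parameter description above of decorator_opts as options for the chosen decorator.
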