4. Optimization

quimb supports optimizing an arbitrary tensor network with respect to an arbitrary ‘loss’ function using automatic differentiation. This is encapsulated in the TNOptimizer object. Internally, this makes use of one of various autodiff libraries to compute the necessary tensor gradients, then maps these tensor parameters into a ‘ravelled’ single real vector for optimization in an outer loop by scipy.optimize.minimize().
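
Schematically, the ravelling step looks something like the following toy sketch (for illustration only, not quimb’s actual internals):

import numpy as np

# toy sketch of 'ravelling': flatten several arrays into the single real
# vector seen by scipy.optimize.minimize, and unflatten it again inside
# the objective function
shapes = [(3, 4), (5,)]
arrays = [np.random.randn(*s) for s in shapes]

def ravel(arrays):
    return np.concatenate([a.ravel() for a in arrays])

def unravel(x):
    out, i = [], 0
    for s in shapes:
        size = int(np.prod(s))
        out.append(x[i:i + size].reshape(s))
        i += size
    return out

x = ravel(arrays)      # shape (17,) vector for the outer optimizer
arrays2 = unravel(x)   # back to the original tensor shapes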

Note

You can also use quimb simply as a way to orchestrate operations on e.g. torch or jax arrays, and then use optimization libraries from those frameworks directly. This is demonstrated in the examples Using quimb within torch and Using quimb within jax, flax and optax.

The basic steps are:

  1. Define a target tensor network (or pytree[1] of TensorNetwork, Tensor, raw array, or Circuit objects).

  2. Optionally, if it is a single tensor network, define sets of tags specifying which tensors to optimize, which to keep constant, and which to share parameters among.

  3. Optionally define a norm_fn that takes the target and projects or constrains it to some valid form, e.g. normalizing or mapping into unitary form. This is the identity by default.

  4. Define a loss function that takes tn (or norm_fn(tn)) and returns a single real scalar to minimize.

Here we’ll demonstrate this with a PBC MPS optimization for the Heisenberg model.

%config InlineBackend.figure_formats = ['svg']
import quimb as qu
import quimb.tensor as qtn
L = 64
D = 16
pbc = True

# create a random MPS as our initial target to optimize
psi = qtn.MPS_rand_state(L, bond_dim=D, cyclic=pbc)

# create the hamiltonian MPO, this is a constant TN not to be optimized
ham = qtn.MPO_ham_heis(L, cyclic=pbc)

Next we define our norm_fn, which here just normalizes the MPS, and our loss_fn, which computes the energy of the Heisenberg model by exactly contracting an MPS-MPO-MPS overlap.

def norm_fn(psi):
    # we could always define this within the loss function, but separating it
    # out can be clearer - it's also called before returning the optimized TN
    nfact = (psi.H @ psi)**0.5
    return psi.multiply(1 / nfact, spread_over='all')


def loss_fn(psi, ham):
    b, h, k = qtn.tensor_network_align(psi.H, ham, psi)
    energy_tn = b | h | k
    return energy_tn ^ ...
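
As a quick extra sanity check (not part of the main flow), the normalized state returned by norm_fn should have unit norm:

psi_n = norm_fn(psi)
psi_n.H @ psi_n  # should be ~1.0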

We can check the initial loss value with:

loss_fn(norm_fn(psi), ham)
-0.16003514944501307

Next we supply these to a TNOptimizer object. Since we have an extra tensor object ham that is needed to compute the loss but should not be optimized, we pass it in loss_constants, which allows it to be converted to the correct backend etc.

tnopt = qtn.TNOptimizer(
    # the tensor network we want to optimize
    psi,
    # the functions specifying the loss and normalization
    loss_fn=loss_fn,
    norm_fn=norm_fn,
    # we specify constants so that the arguments can be converted
    # to the desired autodiff backend automatically
    loss_constants={"ham": ham},
    # the underlying algorithm to use for the optimization
    # 'l-bfgs-b' is the default and often good for fast initial progress
    optimizer="adam",
    # which gradient computation backend to use
    autodiff_backend="jax",
)
tnopt
<TNOptimizer(d=32768, backend=jax)>

Hint

You can also pass general non-numeric or non-tensor options in loss_kwargs.
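
For example, a hypothetical scale option (not part of the example above) could be threaded through to the loss like this:

def loss_fn_scaled(psi, ham, scale=1.0):
    b, h, k = qtn.tensor_network_align(psi.H, ham, psi)
    return scale * ((b | h | k) ^ ...)

tnopt_scaled = qtn.TNOptimizer(
    psi,
    loss_fn=loss_fn_scaled,
    norm_fn=norm_fn,
    loss_constants={"ham": ham},
    # plain python values, passed straight through to the loss function
    loss_kwargs={"scale": 1.0},
    autodiff_backend="jax",
)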

We can see there are 32,768 parameters to optimize, which would be tricky without gradients. We are ready to start optimizing (note that for backends like jax, which compile the computation by default, there is some initial overhead):

psi_opt = tnopt.optimize(1000)
-28.295518875122 [best: -28.295518875122] : : 1001it [01:03, 15.81it/s]                        

There is a simple tnopt.plot method to visualize the loss progress (note that by default the first 20 points are shown on a linear scale and the rest on a log scale):

tnopt.plot(hlines={'analytic': qu.heisenberg_energy(L)})
(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='Iteration', ylabel='Loss'>)

We can check that the returned optimized target psi_opt indeed matches the final loss:

loss_fn(psi_opt, ham)
-28.295518668947235

Note this TN (which can also be retrieved from tnopt.get_tn_opt) is a copy of the original target TN, with the optimized parameters set. It has also been passed through norm_fn, so it is in normalized/projected form, and has been converted back to numpy-backed arrays.
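
For example, a quick check that the arrays are indeed numpy-backed again:

{t.backend for t in psi_opt.tensors}  # e.g. {'numpy'}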

4.1. Using tags to opt in, opt out or group tensors

There are three mutually exclusive options when it comes to specifying exactly which tensors to optimize.

  1. Opt in: specify tags that tensors must have in order to be optimized; any tensors without these tags are assumed to be constant.

  2. Opt out: specify a set of constant_tags that tensors must have in order to be kept constant; any tensors without these tags are assumed to be optimized.

  3. pytree: supply an arbitrary pytree of objects to use however you like within the norm_fn and loss_fn; in this case all tensors are assumed to be optimized.

In all cases you can supply loss_constants, which are passed to the loss_fn but not optimized; this should be a dict containing arbitrary pytree values.
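
As an illustration of the ‘opt out’ style, here is a minimal sketch that would hold the first tensor of our MPS (tagged I0) fixed while optimizing the rest:

tnopt_optout = qtn.TNOptimizer(
    psi,
    loss_fn=loss_fn,
    norm_fn=norm_fn,
    loss_constants={"ham": ham},
    # tensors tagged 'I0' are kept constant, all others are optimized
    constant_tags=["I0"],
    autodiff_backend="jax",
)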

In the first case, you can also specify shared_tags, which we demonstrate here: every tensor carrying one of the shared tags is assumed to share its parameters with every other tensor carrying the same tag. For example, we can create a unit cell of size 2, and specify that our MPS be composed of just two repeated tensors, A and B:

for site in psi.sites:
    t = psi[site]
    if site % 2 == 0:
        t.add_tag("A")
        t.modify(data=psi[0].data)
    else:
        t.add_tag("B")
        t.modify(data=psi[1].data)

psi.normalize()
psi.equalize_norms_()
MatrixProductState(tensors=64, indices=128, L=64, max_bond=16)
Tensor(shape=(16, 16, 2), inds=[_8db77cAAAAA, _8db77cAAAAB, k0], tags={I0, A}), backend=numpy, dtype=float64, data=...
Tensor(shape=(16, 16, 2), inds=[_8db77cAAAAB, _8db77cAAAAC, k1], tags={I1, B}), backend=numpy, dtype=float64, data=...
Tensor(shape=(16, 16, 2), inds=[_8db77cAAAAC, _8db77cAAAAD, k2], tags={I2, A}), backend=numpy, dtype=float64, data=...
Tensor(shape=(16, 16, 2), inds=[_8db77cAAAAD, _8db77cAAAAE, k3], tags={I3, B}), backend=numpy, dtype=float64, data=...
... (the alternating A / B pattern continues through all 64 sites) ...
Tensor(shape=(16, 16, 2), inds=[_8db77cAAABL, _8db77cAAAAA, k63], tags={I63, B}), backend=numpy, dtype=float64, data=...
psi.draw(["A", "B"], figsize=(4, 4))
tnopt = qtn.TNOptimizer(
    psi,
    loss_fn=loss_fn,
    norm_fn=norm_fn,
    loss_constants={"ham": ham},
    optimizer="adam",
    autodiff_backend="jax",
    # only optimize the tensors with these tags (in this case all)
    tags=["A", "B"],
    # within those, group all with each of these tags together
    shared_tags=["A", "B"],
)
tnopt
<TNOptimizer(d=1024, backend=jax)>

You can see the dramatic reduction in the number of parameters to optimize, from 32,768 down to 1,024. The optimization proceeds in the same way as before, but now all tensors tagged A are constrained to be identical, and likewise for B.

tnopt.optimize(1000)
-28.297351837158 [best: -28.297351837158] : : 1001it [00:41, 23.90it/s]                        
MatrixProductState(tensors=64, indices=128, L=64, max_bond=16)
Tensor(shape=(16, 16, 2), inds=[_8db77cAAAAA, _8db77cAAAAB, k0], tags={I0, A}), backend=numpy, dtype=float32, data=...
Tensor(shape=(16, 16, 2), inds=[_8db77cAAAAB, _8db77cAAAAC, k1], tags={I1, B}), backend=numpy, dtype=float32, data=...
Tensor(shape=(16, 16, 2), inds=[_8db77cAAAAC, _8db77cAAAAD, k2], tags={I2, A}), backend=numpy, dtype=float32, data=...
Tensor(shape=(16, 16, 2), inds=[_8db77cAAAAD, _8db77cAAAAE, k3], tags={I3, B}), backend=numpy, dtype=float32, data=...
... (the alternating A / B pattern continues through all 64 sites) ...
Tensor(shape=(16, 16, 2), inds=[_8db77cAAABL, _8db77cAAAAA, k63], tags={I63, B}), backend=numpy, dtype=float32, data=...
tnopt.plot(hlines={'analytic': qu.heisenberg_energy(L)})
(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='Iteration', ylabel='Loss'>)

The reduction in parameters also helps the optimization converge faster.

As a sanity check, we can explicitly visualize the first four tensors of the optimized MPS and see that they do indeed repeat:

psi_opt = tnopt.get_tn_opt()
psi_opt[:4].visualize_tensors("row")
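We could also confirm the sharing numerically, e.g. with this quick sketch:

import numpy as np

# even sites all carry the 'A' parameters, odd sites the 'B' parameters
assert np.allclose(psi_opt[0].data, psi_opt[2].data)
assert np.allclose(psi_opt[1].data, psi_opt[3].data)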

4.2. Optimizing Circuit objects

One special case of optimizing pytrees is when Circuit objects are encountered. These contain a tensor network representation of the quantum circuit, but only the gates/tensors that have been marked as parametrized are optimized.

import autoray as ar

rng = ar.do('random.default_rng', 42, like="numpy")

circ = qtn.Circuit(2)
circ.u3(*rng.uniform(size=3, high=2 * qu.pi), 0, parametrize=True)
circ.u3(*rng.uniform(size=3, high=2 * qu.pi), 1, parametrize=True)
circ.cnot(0, 1)
circ.u3(*rng.uniform(size=3, high=2 * qu.pi), 0, parametrize=True)
circ.u3(*rng.uniform(size=3, high=2 * qu.pi), 1, parametrize=True)

H = qu.ham_heis(2).astype("complex128")

def loss(circ, H):
    en = circ.local_expectation(H, (0, 1), simplify_sequence="ADCRS")
    # we use `autoray.do` to allow arbitrary autodiff backends
    return ar.do("real", en)

tnopt = qtn.TNOptimizer(
    circ,
    loss,
    loss_constants=dict(H=H),
    # because we are using dynamic (entry-dependent) simplification,
    # we need an 'eager' autodiff backend rather than a compiled one
    autodiff_backend="autograd",
)
circ_opt = tnopt.optimize(10)
-0.749999999208 [best: -0.749999999208] : : 13it [00:00, 39.17it/s]                      

The returned circuit now has the optimized parameters set.

circ_opt.gates
(<Gate(label=U3, params=[4.71234134 2.55777369 5.39472984], qubits=(0,), parametrize=True)>,
 <Gate(label=U3, params=[3.14157063 0.52032894 6.13001603], qubits=(1,), parametrize=True)>,
 <Gate(label=CNOT, params=[], qubits=(0, 1))>,
 <Gate(label=U3, params=[4.20005643 5.20517656 0.60518082], qubits=(0,), parametrize=True)>,
 <Gate(label=U3, params=[2.08314126 2.06360383 6.30453863], qubits=(1,), parametrize=True)>)
loss(circ_opt, H)
-0.7499999992075422

But the initial state and constant gates have not been changed, and because the parametrized tensors are manifestly unitary, the state is always exactly normalized.
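
We can verify the normalization directly, e.g.:

# overlap of the optimized circuit state with itself
psi_c = circ_opt.psi
psi_c.H @ psi_c  # ≈ 1.0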

Note

Note also that because we used simplify_sequence="ADCRS", which performs TN simplifications that depend on the tensor entries, we cannot use a statically compiled autodiff backend like jax. Instead we use the ‘eager’ library autograd here. Another option would be to use simplify_sequence="R" instead.
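
A minimal sketch of that alternative, reusing circ and H from above: with only structural (entry-independent) simplification, a compiling backend such as jax works too.

def loss_static(circ, H):
    # "R" (rank simplification) depends only on the TN structure,
    # so the whole computation can be statically compiled
    en = circ.local_expectation(H, (0, 1), simplify_sequence="R")
    return ar.do("real", en)

tnopt_jax = qtn.TNOptimizer(
    circ,
    loss_static,
    loss_constants=dict(H=H),
    autodiff_backend="jax",
)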

If you want fine control over which gates to optimize and share parameters among in the circuit, it is best to extract the TN representation first and optimize that directly. You can then call circ.update_params_from(tn) to update the circuit parameters from the optimized tensor network, as in the sketch below.
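
A minimal sketch of that manual route (the opt-in tag "U3" below is an assumption about how the parametrized gate tensors are tagged):

# extract the full state TN, including the parametrized gate tensors
tn = circ.psi

def tn_loss(tn, H):
    # the same local energy as before, computed directly on the TN
    tnH = tn.gate(H, (0, 1))
    return ar.do("real", (tn.H & tnH).contract())

tnopt_tn = qtn.TNOptimizer(
    tn,
    tn_loss,
    # opt in: only tensors tagged 'U3' are optimized
    tags=["U3"],
    loss_constants=dict(H=H),
    autodiff_backend="autograd",
)
tn_opt = tnopt_tn.optimize(10)

# push the optimized parameters back into the circuit
circ.update_params_from(tn_opt)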

4.3. Optimizing PTensor objects

Behind the scenes, the parametrized gates of the circuit object use a ‘parametrized’ tensor object, PTensor, which holds a PArray. This is a generalization of Tensor whose data is defined by a function applied to some parameters (somewhat like a local norm_fn that is always applied). You can use these directly for even finer control.

Here we show a rather roundabout way of trying to diagonalize a non-symmetric matrix using two orthogonal matrices, Ua and Ub.

# each parametrized tensor has implicit `data = fn(params)`

Ua = qtn.PTensor(
    # project into isometric / unitary / isometric form
    fn=qtn.decomp.isometrize_cayley,
    # parameters here are some arbitrary array
    params=qu.identity(10, dtype="float64"),
    inds=["w", "x"],
    tags="Ua",
)

G = qtn.Tensor(
    data=qu.randn((10, 10)),
    inds=['x', 'y'],
    tags="G",
)

Ub = qtn.PTensor(
    fn=qtn.decomp.isometrize_cayley,
    params=qu.identity(10, dtype="float64"),
    inds=["y", "z"],
    tags="Ub",
)

(Ua | G | Ub).draw(["Ua", "G", "Ub"])
(Ua | G | Ub).contract().visualize()
(<Figure size 500x500 with 1 Axes>, <Axes: >)
def loss(target, G):
    Ua, Ub = target["Ua"], target["Ub"]
    UGU = (Ua | G | Ub).contract(output_inds=['w', 'z']).data
    # minimize off-diagonal elements of UGU
    return ar.do("linalg.norm", UGU - ar.do('diag', ar.do('diag', UGU)))

We’ll also make use here of the automatic Hessian-vector product computation offered by some backends (e.g. jax), which can be used with second-order optimization methods such as Newton-CG.

tnopt = qtn.TNOptimizer(
    {"Ua": Ua, "Ub": Ub},
    loss,
    loss_constants=dict(G=G),
    optimizer="newton-cg",
    autodiff_backend="jax",
)
to = tnopt.optimize(1000, hessp=True)
+0.000001728355 [best: +0.000001728355] :  60%|██████    | 601/1000 [00:01<00:01, 307.78it/s]
tnopt.plot()
(<Figure size 640x480 with 1 Axes>, <Axes: xlabel='Iteration', ylabel='Loss'>)
Uao, Ubo = to["Ua"], to["Ub"]
Uao
PTensor(shape=(10, 10), inds=[w, x], tags={Ua}), backend=numpy, dtype=None, data=array([[ 0.32642075, -0.6051291 ,  0.2978906 , ...], ..., [ 0.5856843 , -0.11639551, -0.19472294, ...]], dtype=float32)
# still orthogonal: for a 10x10 orthogonal matrix U, ||U||_F^2 = tr(U^T U) = 10
Uao.norm()**2, Ubo.norm()**2
(10.000000000000002, 10.000000000000002)
# check gauged G
(Uao | G | Ubo).contract().visualize()
(<Figure size 500x500 with 1 Axes>, <Axes: >)