4. Optimization

import quimb as qu
import quimb.tensor as qtn

# system size, bond dimension and boundary conditions
L = 64
D = 16
pbc = True

# a random MPS as the initial guess, and the Heisenberg Hamiltonian as an MPO
psi = qtn.MPS_rand_state(L, bond_dim=D, cyclic=pbc)
ham = qtn.MPO_ham_heis(L, cyclic=pbc)
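
As a quick sanity check before defining the loss, one might inspect the random starting state (an illustrative step, using only the objects defined above):

# maximum bond dimension of the random MPS -- should equal D
psi.max_bond()
# squared norm of the (as yet unnormalized) starting state
psi.H @ psi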
def norm_fn(psi):
    # divide by the norm, spreading the factor evenly over all tensors
    nfact = (psi.H @ psi)**0.5
    return psi.multiply(1 / nfact, spread_over='all')
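
To confirm norm_fn behaves as intended, a quick illustrative check that the returned state has unit norm:

psi_n = norm_fn(psi)
# should be (numerically) 1.0
psi_n.H @ psi_n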


def loss_fn(psi, ham):
    # align the bra, MPO and ket indices, then lazily combine them
    b, h, k = qtn.tensor_network_align(psi.H, ham, psi)
    energy_tn = b | h | k
    # contract everything to get the energy expectation value
    return energy_tn ^ ...
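
Evaluating the loss on the normalized starting state gives a baseline energy for the random MPS (illustrative; the exact value depends on the random state):

# energy expectation of the normalized random starting state
loss_fn(norm_fn(psi), ham)
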
tnopt = qtn.TNOptimizer(
    # the tensor network we want to optimize
    psi,
    # the functions specifying the loss and normalization
    loss_fn=loss_fn,
    norm_fn=norm_fn,
    # we specify constants so that the arguments can be converted
    # to the desired autodiff backend automatically
    loss_constants={"ham": ham},
    # the underlying algorithm to use for the optimization
    # ('l-bfgs-b' is the default; here we use 'adam')
    optimizer="adam",
    autodiff_backend="jax",
)
tnopt
<TNOptimizer(d=32768, backend=jax)>
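
Here d=32768 is the total number of parameters being optimized: 64 cyclic MPS tensors, each of shape (16, 16, 2):

# 64 tensors, each with D * D * 2 entries
64 * 16 * 16 * 2
32768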
tnopt.optimize(1000)
-28.294601440430 [best: -28.294666290283] : : 1001it [00:38, 26.15it/s]                        
<MatrixProductState(tensors=64, indices=128, L=64, max_bond=16)>
tnopt.plot(hlines={'analytic': qu.heisenberg_energy(L)})
(<Figure size 640x480 with 1 Axes>,
 <AxesSubplot:xlabel='Iteration', ylabel='Loss'>)
[figure: loss per iteration, with the analytic ground-state energy marked as a horizontal line]
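
One can also compare the energy found against the analytic result directly (an illustrative check, reusing get_tn_opt as below):

en_analytic = qu.heisenberg_energy(L)
en_opt = loss_fn(tnopt.get_tn_opt(), ham)
# relative error of the optimized energy
abs((en_opt - en_analytic) / en_analytic)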

4.1. Using tags to opt in, opt out or group tensors

for site in psi.sites:
    t = psi[site]
    if site % 2 == 0:
        # even sites: tag "A" and copy in the data of site 0
        t.add_tag("A")
        t.modify(data=psi[0].data)
    else:
        # odd sites: tag "B" and copy in the data of site 1
        t.add_tag("B")
        t.modify(data=psi[1].data)
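
An illustrative check that the tags were applied as intended (32 tensors should carry each tag):

psi.num_tensors, len(psi.select_tensors("A")), len(psi.select_tensors("B"))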

# renormalize, then spread the overall norm evenly across all tensors
psi.normalize()
psi.equalize_norms_()
<MatrixProductState(tensors=64, indices=128, L=64, max_bond=16)>
psi.draw(["A", "B"], figsize=(4, 4))
[figure: the MPS drawn with the alternating "A" and "B" tags highlighted]
tnopt = qtn.TNOptimizer(
    psi,
    loss_fn=loss_fn,
    norm_fn=norm_fn,
    loss_constants={"ham": ham},
    optimizer="adam",
    autodiff_backend="jax",
    # only optimize the tensors with these tags (in this case all)
    tags=["A", "B"],
    # within those, group all with each of these tags together
    shared_tags=["A", "B"],
)
tnopt
<TNOptimizer(d=1024, backend=jax)>
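
The parameter count drops from 32768 to 1024 because only two unique tensors are optimized, one shared by every "A" site and one by every "B" site:

# two shared tensors, each of shape (16, 16, 2)
2 * 16 * 16 * 2
1024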
tnopt.optimize(1000)
-28.304227828979 [best: -28.304290771484] : : 1001it [00:28, 35.30it/s]                        
<MatrixProductState(tensors=64, indices=128, L=64, max_bond=16)>
tnopt.plot(hlines={'analytic': qu.heisenberg_energy(L)})
(<Figure size 640x480 with 1 Axes>,
 <AxesSubplot:xlabel='Iteration', ylabel='Loss'>)
[figure: loss per iteration for the tagged optimization, with the analytic energy as a horizontal line]
psi_opt = tnopt.get_tn_opt()
psi_opt[0].visualize(legend=True)[0]
[figure: visualization of the optimized, shared "A" tensor]
psi_opt[1].visualize(legend=True)[0]
[figure: visualization of the optimized, shared "B" tensor]
loss_fn(psi_opt, ham)
-28.30422033705775
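
Finally, an illustrative check that the parameter sharing was respected: every tensor tagged "A" in the optimized state should hold identical data, and likewise for "B":

import numpy as np

tsA = psi_opt.select_tensors("A")
tsB = psi_opt.select_tensors("B")
# (True, True) if the "A" and "B" groups each share a single array
(
    all(np.allclose(t.data, tsA[0].data) for t in tsA),
    all(np.allclose(t.data, tsB[0].data) for t in tsB),
)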