10. Tensor Network Training of Quantum Circuits

Here we’ll run through constructing a tensor network of an ansatz quantum circuit, then training certain ‘parametrizable’ tensors representing quantum gates in that tensor network to replicate the behaviour of a target unitary.

%config InlineBackend.figure_formats = ['svg']
import quimb as qu
import quimb.tensor as qtn

10.1. The Ansatz Circuit

First we set up the ansatz circuit and extract the tensor network. The key here is that when we supply parametrize=True to the 'U3' gate call, it injects a PTensor into the network, which lazily represents its data array with a function and a set of parameters. Later, when the optimizer sees this, it knows to optimize the parameters rather than the array itself.
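To make "lazily represents its data array with a function and a set of parameters" concrete, here is a minimal pure-Python sketch of the pattern. `LazyParamTensor` and `rx` are hypothetical names for illustration only. This is not quimb's `PTensor` implementation, just the idea it describes: the array is regenerated from a function and a parameter list, so an optimizer only ever has to change the parameters.

```python
import math


def rx(theta):
    """Example gate function: a single-qubit X rotation as nested lists."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -1j * s], [-1j * s, c]]


class LazyParamTensor:
    """Hypothetical stand-in for a parametrized tensor: stores (fn, params), not data."""

    def __init__(self, fn, params):
        self.fn = fn
        self.params = list(params)

    @property
    def data(self):
        # the array is rebuilt on demand from whatever the parameters currently are
        return self.fn(*self.params)


t = LazyParamTensor(rx, [0.0])
print(t.data)         # identity matrix at theta = 0
t.params = [math.pi]  # an optimizer only ever needs to update this list
print(t.data)         # now approximately -i * X
```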

def single_qubit_layer(circ, gate_round=None):
    """Apply a parametrizable layer of single qubit ``U3`` gates.
    """
    for i in range(circ.N):
        # initialize with random parameters
        params = qu.randn(3, dist='uniform')
        circ.apply_gate(
            'U3', *params, i, 
            gate_round=gate_round, parametrize=True)
        
def two_qubit_layer(circ, gate2='CZ', reverse=False, gate_round=None):
    """Apply a layer of constant entangling gates.
    """
    regs = range(0, circ.N - 1)
    if reverse:
        regs = reversed(regs)
    
    for i in regs:
        circ.apply_gate(
            gate2, i, i + 1, gate_round=gate_round)

def ansatz_circuit(n, depth, gate2='CZ', **kwargs):
    """Construct a circuit of single qubit and entangling layers.
    """
    circ = qtn.Circuit(n, **kwargs)
    
    for r in range(depth):
        # single qubit gate layer
        single_qubit_layer(circ, gate_round=r)
        
        # alternate between forward and backward CZ layers
        two_qubit_layer(
            circ, gate2=gate2, gate_round=r, reverse=r % 2 == 0)
        
    # add a final single qubit layer (round ``depth``); indexing by ``depth``
    # rather than ``r + 1`` also works when depth == 0
    single_qubit_layer(circ, gate_round=depth)
    
    return circ

The form of the 'U3' gate, which can represent any single-qubit unitary up to a global phase, is given by qu.U_gate(). Now we are ready to instantiate a circuit:
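For reference, the standard U3 matrix is written out below. Evaluating it at the trained round-4 parameters agrees with the qu.U_gate output printed later in this section, so it appears to be the convention used here:

```latex
U_3(\theta, \phi, \lambda) =
\begin{pmatrix}
  \cos\frac{\theta}{2} & -e^{i\lambda}\sin\frac{\theta}{2} \\
  e^{i\phi}\sin\frac{\theta}{2} & e^{i(\phi + \lambda)}\cos\frac{\theta}{2}
\end{pmatrix}
```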

n = 6
depth = 9
gate2 = 'CZ'

circ = ansatz_circuit(n, depth, gate2=gate2)
circ
<Circuit(n=6, num_gates=105, gate_opts={'contract': 'auto-split-gate', 'propagate_tags': 'register'})>
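As a quick arithmetic check of num_gates=105, note that each of the depth rounds applies n U3 gates and n - 1 entangling gates, and one final U3 layer is added after the loop. The helper below is hypothetical and not part of quimb:

```python
def count_gates(n, depth):
    """Count gates in the ansatz: depth x (n U3 + (n - 1) two-qubit) + n final U3."""
    n_single = (depth + 1) * n  # U3 layers, including the final one
    n_two = depth * (n - 1)     # one entangling layer per round
    return n_single, n_two, n_single + n_two


n_single, n_two, total = count_gates(6, 9)
print(n_single, n_two, total)             # 60 45 105
print("trainable params:", 3 * n_single)  # 3 angles per U3 -> 180
```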

We can extract just the unitary part of the circuit as a tensor network like so:

V = circ.uni
FutureWarning: In future the tensor network returned by ``circ.uni`` will not be transposed as it is currently, to match the expectation from ``U = circ.uni.to_dense()`` behaving like ``U @ psi``. You can retain this behaviour with ``circ.get_uni(transposed=True)``.

You can see it already has various tags simultaneously identifying different structures:

# types of gate
V.draw(color=['U3', gate2], show_inds=True)
# layers of gates
V.draw(color=[f'ROUND_{i}' for i in range(depth + 1)], show_inds=True)
# what register each tensor is 'above'
V.draw(color=[f'I{i}' for i in range(n)], show_inds=True)
# a unique tag per gate applied
V.draw(color=[f'GATE_{i}' for i in range(circ.num_gates)], legend=False)

10.2. The Target Unitary

Next we need a target unitary to try to digitally replicate. Here we'll take an Ising Hamiltonian and a short time evolution. Once we have the dense (matrix) form of the target unitary \(U\), we need to convert it to a tensor that we can put in a tensor network:

# the hamiltonian
H = qu.ham_ising(n, jz=1.0, bx=0.7, cyclic=False)

# the propagator for the hamiltonian
t = 2
U_dense = qu.expm(-1j * t * H)

# 'tensorized' version of the unitary propagator
U = qtn.Tensor(
    data=U_dense.reshape([2] * (2 * n)),
    inds=[f'k{i}' for i in range(n)] + [f'b{i}' for i in range(n)],
    tags={'U_TARGET'}
)
U.draw(color=['U3', gate2, 'U_TARGET'])
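The reshape into [2] * (2 * n) splits the row index of U_dense into the k bits and the column index into the b bits. Because NumPy reshapes in C (row-major) order by default, k0 (qubit 0) is the most significant bit. Here is a small pure-Python sketch of that mapping. The helper name `tensor_index` is hypothetical:

```python
def tensor_index(row, col, n):
    """Map a (row, col) entry of a 2**n x 2**n matrix to the (k0..k{n-1}, b0..b{n-1})
    multi-index produced by a C-order reshape to [2] * (2 * n)."""
    def bits(x):
        # most significant bit first, so position 0 corresponds to qubit 0
        return tuple((x >> (n - 1 - i)) & 1 for i in range(n))
    return bits(row) + bits(col)


# entry (row=1, col=2) of a 2-qubit (4 x 4) matrix
print(tensor_index(1, 2, 2))  # (0, 1, 1, 0)
```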

The core object describing how similar two unitaries are is \(\mathrm{Tr}(V^{\dagger}U)\), which we can naturally visualize as a tensor network:

(V.H & U).draw(color=['U3', gate2, 'U_TARGET'])

For our loss function we'll take the absolute value of this overlap, normalize it by the dimension \(2^n\), and subtract it from 1 (since the optimizer minimizes):

def loss(V, U):
    return 1 - abs((V.H & U).contract(all, optimize='auto-hq')) / 2**n

# check our current unitary 'infidelity':
loss(V, U)
0.9881326237593226

As expected, the two unitaries are currently not at all similar.
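Since \(V^{\dagger}U\) is unitary, its eigenvalues have modulus 1, so \(|\mathrm{Tr}(V^{\dagger}U)| \le d\). The loss therefore lies in [0, 1] and reaches 0 exactly when V equals U up to a global phase. Here is a pure-Python sketch of the same quantity on small dense matrices stored as nested lists. The function name is hypothetical:

```python
import cmath


def unitary_infidelity(V, U):
    """1 - |Tr(V^dagger U)| / d for square matrices given as nested lists."""
    d = len(U)
    # Tr(V^dagger U) = sum_ij conj(V_ij) U_ij, i.e. contracting every index pair
    overlap = sum(V[i][j].conjugate() * U[i][j] for i in range(d) for j in range(d))
    return 1 - abs(overlap) / d


I = [[1 + 0j, 0j], [0j, 1 + 0j]]
X = [[0j, 1 + 0j], [1 + 0j, 0j]]
phase = cmath.exp(0.3j)
print(unitary_infidelity(I, I))                                        # 0.0
print(unitary_infidelity([[phase * x for x in row] for row in I], I))  # ~0: global phase ignored
print(unitary_infidelity(I, X))                                        # 1.0: orthogonal gates
```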

10.3. The Tensor Network Optimization

Now we are ready to construct the TNOptimizer object, with options detailed below:

# use the autograd/jax based optimizer

tnopt = qtn.TNOptimizer(
    V,                        # the tensor network we want to optimize
    loss,                     # the function we want to minimize
    loss_constants={'U': U},  # supply U to the loss function as a constant TN
    tags=['U3'],              # only optimize U3 tensors
    autodiff_backend='jax',   # use 'autograd' for non-compiled optimization
    optimizer='L-BFGS-B',     # the optimization algorithm
)

Note

If tags is not specified, the default is to optimize all tensors. In that case, instead of specifying tags to opt tensors in, you can use constant_tags to opt out the tensors you don't want to optimize, which may be more convenient.
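Here is a hypothetical sketch of the two selection styles in the note. It assumes a tensor is opted in if it carries any of `tags` and opted out if it carries any of `constant_tags`. That rule is an assumption for illustration, not a description of TNOptimizer's internals. With the tags used here, opting in 'U3' and opting out 'CZ' select the same tensors:

```python
def select_trainable(tensor_tags, tags=None, constant_tags=()):
    """Hypothetical opt-in / opt-out selection over a list of per-tensor tag sets.

    Returns the indices of the tensors that would be trained, ASSUMING a tensor is
    opted in if it has any of ``tags`` and opted out if it has any of ``constant_tags``.
    """
    chosen = []
    for i, ttags in enumerate(tensor_tags):
        opted_in = tags is None or any(t in ttags for t in tags)
        opted_out = any(t in ttags for t in constant_tags)
        if opted_in and not opted_out:
            chosen.append(i)
    return chosen


tn_tags = [{"U3", "I0"}, {"CZ", "I0"}, {"CZ", "I1"}, {"U3", "I1"}]
print(select_trainable(tn_tags, tags=["U3"]))           # [0, 3]
print(select_trainable(tn_tags, constant_tags=["CZ"]))  # [0, 3]
print(select_trainable(tn_tags))                        # [0, 1, 2, 3]
```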

We could call optimize for purely gradient-based optimization, but since unitary circuits can be tricky to train, we'll use optimize_basinhopping, which combines gradient descent with ‘hopping’ to escape local minima:

# allow 10 hops with 500 steps in each 'basin'
V_opt = tnopt.optimize_basinhopping(n=500, nhop=10)
+0.005301594734 [best: +0.005261898041] :  45%|████▌     | 2258/5000 [00:40<00:49, 55.26it/s]
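The progress bar total of 5000 is consistent with the step budget n × nhop = 500 × 10. The pure-Python sketch below illustrates the idea behind basin hopping, not quimb's implementation. It uses a 1D function with many local minima: relax with local gradient descent, apply a random hop, relax again, and keep the new point only if it is lower. This greedy acceptance is a simplification. scipy.optimize.basinhopping, for example, uses a Metropolis criterion instead. The function, step sizes, and hop count are arbitrary choices for the example:

```python
import math
import random


def f(x):
    # global minimum f(0) = -1, with local minima near every multiple of pi
    return x * x / 10 - math.cos(2 * x)


def grad_f(x):
    return x / 5 + 2 * math.sin(2 * x)


def local_minimize(x, lr=0.05, steps=200):
    """Plain gradient descent, standing in for a local optimizer like L-BFGS-B."""
    for _ in range(steps):
        x -= lr * grad_f(x)
    return x


def basin_hop(x0, nhop=300, hop_size=3.0, seed=0):
    rng = random.Random(seed)
    x = local_minimize(x0)
    best_x, best_f = x, f(x)
    for _ in range(nhop):
        # random perturbation, then relax into whichever basin we landed in
        trial = local_minimize(x + rng.uniform(-hop_size, hop_size))
        if f(trial) < f(x):  # greedy acceptance (simplified)
            x = trial
        if f(x) < best_f:
            best_x, best_f = x, f(x)
    return best_x, best_f


x_best, f_best = basin_hop(6.0)
print(f"local min from x0=6: f = {f(local_minimize(6.0)):.3f}")
print(f"after hopping:       f = {f_best:.3f} at x = {x_best:.3f}")
```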

The optimized tensor network still contains PTensor instances, but now with optimized parameters. For example, here's the tensor of the U3 gate acting on qubit 2 in round 4:

V_opt['U3', 'I2', 'ROUND_4']
PTensor(shape=(2, 2), inds=('_8d3684AAABr', '_8d3684AAABh'), tags=oset(['GATE_46', 'ROUND_4', 'U3', 'I2']))
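Selecting with several tags, as in V_opt['U3', 'I2', 'ROUND_4'], picks out the tensor(s) carrying all of them. Below is a pure-Python sketch of that intersection. The tag layout is an assumption reconstructed from the tags drawn earlier: the gate label, ROUND_r, a register tag I{q}, and GATE_k numbered in application order. Each split CZ is modelled as one tensor per qubit. Under these assumptions the query recovers the GATE_46 tag shown above:

```python
def ansatz_tags(n, depth):
    """Hypothetical reconstruction of the per-tensor tag sets for the ansatz."""
    tensors, gate = [], 0

    def u3_layer(r):
        nonlocal gate
        for q in range(n):
            tensors.append({"U3", f"ROUND_{r}", f"I{q}", f"GATE_{gate}"})
            gate += 1

    for r in range(depth):
        u3_layer(r)
        pairs = range(n - 1)
        for i in (reversed(pairs) if r % 2 == 0 else pairs):
            # assume a split two-qubit gate leaves one tensor above each register
            for q in (i, i + 1):
                tensors.append({"CZ", f"ROUND_{r}", f"I{q}", f"GATE_{gate}"})
            gate += 1
    u3_layer(depth)
    return tensors


def select_all(tensors, *query):
    """Return the tag sets that contain every query tag."""
    return [t for t in tensors if set(query) <= t]


matches = select_all(ansatz_tags(6, 9), "U3", "I2", "ROUND_4")
print(len(matches), sorted(matches[0]))  # 1 ['GATE_46', 'I2', 'ROUND_4', 'U3']
```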

We can see the parameters have been updated by the training:

# the initial values
V['U3', 'ROUND_4', 'I2'].params
array([0.25927184, 0.10598236, 0.42593174])
# the optimized values
V_opt['U3', 'ROUND_4', 'I2'].params
array([ 0.52709323,  0.6036498 , -0.23340225], dtype=float32)

We can see what gate these parameters would generate:

qu.U_gate(*V_opt['U3', 'ROUND_4', 'I2'].params)
[[ 0.965472-0.j       -0.253443+0.060252j]
 [ 0.214467+0.147877j  0.90005 +0.349352j]]
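As a cross-check, this matrix can be reproduced in pure Python from the standard U3 formula. The `u3` helper is hypothetical and not quimb code. The parameters are copied from the printed output, so agreement only holds to the printed precision:

```python
import cmath
import math


def u3(theta, phi, lam):
    """Standard U3 matrix as nested lists."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [
        [c, -cmath.exp(1j * lam) * s],
        [cmath.exp(1j * phi) * s, cmath.exp(1j * (phi + lam)) * c],
    ]


# trained round-4, qubit-2 parameters copied from the output above
G = u3(0.52709323, 0.6036498, -0.23340225)
for row in G:
    print([f"{complex(z).real:+.6f}{complex(z).imag:+.6f}j" for z in row])
```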

A final sanity check is to evolve a random state with both the target unitary and the trained circuit, and then compute the fidelity between the resulting states.

First we turn the tensor network version of \(V\) into a dense matrix:

V_opt_dense = V_opt.to_dense([f'k{i}' for i in range(n)], [f'b{i}' for i in range(n)])

Next we create a random initial state and evolve it with both the target unitary and the trained circuit:

psi0 = qu.rand_ket(2**n)

# this is the exact state we want
psif_exact = U_dense @ psi0

# this is the state our circuit will produce if fed `psi0`
psif_apprx = V_opt_dense @ psi0

The (in)fidelity should broadly match our training loss:

f"Fidelity: {100 * qu.fidelity(psif_apprx, psif_exact):.2f} %"
'Fidelity: 99.57 %'
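For a rough link between the training loss and state fidelities, consider the standard average-fidelity result. Averaged over Haar-random input states, the squared overlap \(|\langle\psi|U^{\dagger}V|\psi\rangle|^2\) equals \((|\mathrm{Tr}(V^{\dagger}U)|^2 + d)/(d(d+1))\), where \(d = 2^n\). The sketch below plugs in a loss of about 0.0053, as reported by the basin-hopping run. A single random state, as used above, will scatter around this average:

```python
def average_state_fidelity(loss, n):
    """Haar-averaged |<psi|U^dagger V|psi>|^2 given loss = 1 - |Tr(V^dagger U)| / 2**n."""
    d = 2 ** n
    trace_abs = d * (1 - loss)  # recover |Tr(V^dagger U)| from the loss
    return (trace_abs ** 2 + d) / (d * (d + 1))


print(average_state_fidelity(0.0, 6))     # 1.0 for a perfect match
print(average_state_fidelity(0.0053, 6))  # roughly 0.99
```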

10.4. Extracting the New Circuit

We can extract the trained circuit parameters by updating the original Circuit object from the trained TN:

circ.update_params_from(V_opt)

# the basic gate specification
circ.gates
[<Gate(label=U3, params=(0.16929962, -0.6534094, -0.25308686), qubits=(0,), round=0, parametrize=True))>,
 <Gate(label=U3, params=(-0.8037065, -0.25300127, 1.5709982), qubits=(1,), round=0, parametrize=True))>,
 <Gate(label=U3, params=(-1.0775858, 0.29377267, 1.573218), qubits=(2,), round=0, parametrize=True))>,
 <Gate(label=U3, params=(-0.70041966, -0.54078287, 1.5730082), qubits=(3,), round=0, parametrize=True))>,
 <Gate(label=U3, params=(2.233548, -0.6997934, 1.5887524), qubits=(4,), round=0, parametrize=True))>,
 <Gate(label=U3, params=(2.6693592, -0.3773012, 1.6181816), qubits=(5,), round=0, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=0)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=0)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=0)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=0)>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=0)>,
 <Gate(label=U3, params=(0.9794985, -0.1482388, 0.3628862), qubits=(0,), round=1, parametrize=True))>,
 <Gate(label=U3, params=(1.3343341, -0.8915439, 0.252669), qubits=(1,), round=1, parametrize=True))>,
 <Gate(label=U3, params=(0.0016000561, 1.3208843, 1.8358507), qubits=(2,), round=1, parametrize=True))>,
 <Gate(label=U3, params=(0.20370977, -0.5142169, 0.52689326), qubits=(3,), round=1, parametrize=True))>,
 <Gate(label=U3, params=(1.7909783, 0.082318805, 0.6477436), qubits=(4,), round=1, parametrize=True))>,
 <Gate(label=U3, params=(-1.6923836, 0.22169495, -0.06101191), qubits=(5,), round=1, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=1)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=1)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=1)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=1)>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=1)>,
 <Gate(label=U3, params=(-1.5676986, 0.79075044, 1.9245409), qubits=(0,), round=2, parametrize=True))>,
 <Gate(label=U3, params=(1.4227079, 0.50024813, 0.8913113), qubits=(1,), round=2, parametrize=True))>,
 <Gate(label=U3, params=(-1.5313544, 0.5816354, -0.30510747), qubits=(2,), round=2, parametrize=True))>,
 <Gate(label=U3, params=(0.05509888, 0.46960974, 0.5104008), qubits=(3,), round=2, parametrize=True))>,
 <Gate(label=U3, params=(-1.4967402, 2.4865503, -0.09552598), qubits=(4,), round=2, parametrize=True))>,
 <Gate(label=U3, params=(1.6648741, 1.5529978, 0.33650705), qubits=(5,), round=2, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=2)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=2)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=2)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=2)>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=2)>,
 <Gate(label=U3, params=(2.3430626, 0.5120296, 0.76947355), qubits=(0,), round=3, parametrize=True))>,
 <Gate(label=U3, params=(1.4037384, 0.34615904, 2.6413994), qubits=(1,), round=3, parametrize=True))>,
 <Gate(label=U3, params=(-1.4830327, 0.22897828, -0.58199924), qubits=(2,), round=3, parametrize=True))>,
 <Gate(label=U3, params=(0.002939354, 0.941026, 2.3286314), qubits=(3,), round=3, parametrize=True))>,
 <Gate(label=U3, params=(-0.19347823, 1.4644276, 0.6541145), qubits=(4,), round=3, parametrize=True))>,
 <Gate(label=U3, params=(0.79052246, 1.3951869, -0.1806278), qubits=(5,), round=3, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=3)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=3)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=3)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=3)>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=3)>,
 <Gate(label=U3, params=(1.7128671, -0.83896357, 0.87925446), qubits=(0,), round=4, parametrize=True))>,
 <Gate(label=U3, params=(1.7145393, -0.38567403, -0.34633058), qubits=(1,), round=4, parametrize=True))>,
 <Gate(label=U3, params=(0.52709323, 0.6036498, -0.23340225), qubits=(2,), round=4, parametrize=True))>,
 <Gate(label=U3, params=(0.005033329, -0.7817636, -0.5969079), qubits=(3,), round=4, parametrize=True))>,
 <Gate(label=U3, params=(2.1565654, 0.751235, -1.485796), qubits=(4,), round=4, parametrize=True))>,
 <Gate(label=U3, params=(1.5731558, 1.591097, 0.31102985), qubits=(5,), round=4, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=4)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=4)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=4)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=4)>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=4)>,
 <Gate(label=U3, params=(-2.325529, 0.418274, -1.1409844), qubits=(0,), round=5, parametrize=True))>,
 <Gate(label=U3, params=(0.011331396, -0.6176183, 0.28041682), qubits=(1,), round=5, parametrize=True))>,
 <Gate(label=U3, params=(1.4015418, 0.5063352, -0.5949921), qubits=(2,), round=5, parametrize=True))>,
 <Gate(label=U3, params=(0.009612298, 0.2217627, 0.04534868), qubits=(3,), round=5, parametrize=True))>,
 <Gate(label=U3, params=(1.8638043, -0.20229527, -0.7953688), qubits=(4,), round=5, parametrize=True))>,
 <Gate(label=U3, params=(-0.02179846, -0.57399744, 0.8407001), qubits=(5,), round=5, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=5)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=5)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=5)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=5)>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=5)>,
 <Gate(label=U3, params=(1.3877774, 0.9551876, 0.9025401), qubits=(0,), round=6, parametrize=True))>,
 <Gate(label=U3, params=(0.002978851, 2.6452973, 0.20050335), qubits=(1,), round=6, parametrize=True))>,
 <Gate(label=U3, params=(0.44300193, -0.30685717, -0.5065869), qubits=(2,), round=6, parametrize=True))>,
 <Gate(label=U3, params=(0.076909065, 1.6145856, 0.6035761), qubits=(3,), round=6, parametrize=True))>,
 <Gate(label=U3, params=(-0.5196627, 1.6240075, 0.20132811), qubits=(4,), round=6, parametrize=True))>,
 <Gate(label=U3, params=(2.2904618, 0.9574056, -0.4791421), qubits=(5,), round=6, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=6)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=6)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=6)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=6)>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=6)>,
 <Gate(label=U3, params=(-0.30187988, -0.87076926, -0.8777541), qubits=(0,), round=7, parametrize=True))>,
 <Gate(label=U3, params=(1.624832, 1.3061262, 1.019058), qubits=(1,), round=7, parametrize=True))>,
 <Gate(label=U3, params=(-0.0049230885, 1.0577729, 1.851811), qubits=(2,), round=7, parametrize=True))>,
 <Gate(label=U3, params=(0.02848778, 1.4706283, 1.2345053), qubits=(3,), round=7, parametrize=True))>,
 <Gate(label=U3, params=(0.2926683, -0.228561, 0.6961813), qubits=(4,), round=7, parametrize=True))>,
 <Gate(label=U3, params=(-0.64344996, -0.6371593, -1.0523038), qubits=(5,), round=7, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=7)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=7)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=7)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=7)>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=7)>,
 <Gate(label=U3, params=(-1.3433273, 1.4197886, 0.5518878), qubits=(0,), round=8, parametrize=True))>,
 <Gate(label=U3, params=(-1.7108984, -0.03164328, -1.3058407), qubits=(1,), round=8, parametrize=True))>,
 <Gate(label=U3, params=(-1.4684929, -0.23041597, 0.5393066), qubits=(2,), round=8, parametrize=True))>,
 <Gate(label=U3, params=(0.008398809, -0.18387741, -0.6591129), qubits=(3,), round=8, parametrize=True))>,
 <Gate(label=U3, params=(-0.000680109, 0.2461698, -0.4091857), qubits=(4,), round=8, parametrize=True))>,
 <Gate(label=U3, params=(0.110012054, 1.3755561, 1.2930975), qubits=(5,), round=8, parametrize=True))>,
 <Gate(label=CZ, params=(), qubits=(4, 5), round=8)>,
 <Gate(label=CZ, params=(), qubits=(3, 4), round=8)>,
 <Gate(label=CZ, params=(), qubits=(2, 3), round=8)>,
 <Gate(label=CZ, params=(), qubits=(1, 2), round=8)>,
 <Gate(label=CZ, params=(), qubits=(0, 1), round=8)>,
 <Gate(label=U3, params=(0.4610416, 0.26438755, 1.6497571), qubits=(0,), round=9, parametrize=True))>,
 <Gate(label=U3, params=(-0.7864091, 1.571062, 0.03173134), qubits=(1,), round=9, parametrize=True))>,
 <Gate(label=U3, params=(-1.0439473, 1.5646937, 0.2356552), qubits=(2,), round=9, parametrize=True))>,
 <Gate(label=U3, params=(0.88693166, 1.5739466, -0.4111447), qubits=(3,), round=9, parametrize=True))>,
 <Gate(label=U3, params=(0.4123549, 1.0429444, 1.6132663), qubits=(4,), round=9, parametrize=True))>,
 <Gate(label=U3, params=(1.3669215, 3.024821, 0.9148983), qubits=(5,), round=9, parametrize=True))>]
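One way to hand the trained circuit to other tools is to serialize it. The sketch below writes an OpenQASM 2.0 string from (label, params, qubits) tuples copied by hand from the listing above. `to_qasm2` is a hypothetical helper, and it deliberately avoids relying on the attributes of quimb's Gate objects. OpenQASM's u3(θ, φ, λ) equals the U3 matrix used here up to a global phase, and cz is defined in qelib1.inc:

```python
def to_qasm2(n, gates):
    """Serialize (label, params, qubits) tuples as OpenQASM 2.0 text."""
    lines = ["OPENQASM 2.0;", 'include "qelib1.inc";', f"qreg q[{n}];"]
    for label, params, qubits in gates:
        name = label.lower()  # 'U3' -> 'u3', 'CZ' -> 'cz'
        args = f"({','.join(repr(float(p)) for p in params)})" if params else ""
        targets = ",".join(f"q[{q}]" for q in qubits)
        lines.append(f"{name}{args} {targets};")
    return "\n".join(lines)


# the first few gates of round 0, copied from the ``circ.gates`` output above
gates = [
    ("U3", (0.16929962, -0.6534094, -0.25308686), (0,)),
    ("U3", (-0.8037065, -0.25300127, 1.5709982), (1,)),
    ("CZ", (), (4, 5)),
]
print(to_qasm2(6, gates))
```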