10. Tensor Network Training of Quantum Circuits
Here we’ll run through constructing a tensor network of an ansatz quantum circuit, then training certain ‘parametrizable’ tensors representing quantum gates in that tensor network to replicate the behaviour of a target unitary.
%config InlineBackend.figure_formats = ['svg']
import quimb as qu
import quimb.tensor as qtn
10.1. The Ansatz Circuit
First we set up the ansatz circuit and extract the tensor network. Key here is that when we supply parametrize=True to the 'U3' gate call, it injects a PTensor into the network, which lazily represents its data array with a function and set of parameters. Later, when the optimizer sees this, it knows to optimize the parameters rather than the array itself.
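The idea of a tensor that stores a function and parameters rather than a fixed array can be illustrated with a minimal sketch. The LazyArray class and rotation function below are hypothetical names used only for illustration; they are not quimb's implementation of PTensor.

```python
import numpy as np


class LazyArray:
    """Hypothetical stand-in: holds a function and its parameters,
    and only builds the actual array when ``data`` is accessed."""

    def __init__(self, fn, params):
        self.fn = fn
        self.params = np.asarray(params, dtype=float)

    @property
    def data(self):
        # the array is recomputed from the current parameters on access
        return self.fn(self.params)


def rotation(params):
    """A 2x2 real rotation matrix defined by a single angle."""
    (theta,) = params
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


t = LazyArray(rotation, [0.0])
before = t.data  # identity matrix for theta = 0

# an optimizer would update ``params``; the array follows automatically
t.params = np.array([np.pi / 2])
after = t.data  # rotation by 90 degrees
```

Because only the parameters are stored, an optimizer working on them can never leave the family of matrices the function generates, which is what keeps a parametrized gate unitary during training.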
def single_qubit_layer(circ, gate_round=None):
    """Apply a parametrizable layer of single qubit ``U3`` gates.
    """
    for i in range(circ.N):
        # initialize with random parameters
        params = qu.randn(3, dist='uniform')
        circ.apply_gate(
            'U3', *params, i,
            gate_round=gate_round, parametrize=True)

def two_qubit_layer(circ, gate2='CZ', reverse=False, gate_round=None):
    """Apply a layer of constant entangling gates.
    """
    regs = range(0, circ.N - 1)
    if reverse:
        regs = reversed(regs)

    for i in regs:
        circ.apply_gate(
            gate2, i, i + 1, gate_round=gate_round)

def ansatz_circuit(n, depth, gate2='CZ', **kwargs):
    """Construct a circuit of single qubit and entangling layers.
    """
    circ = qtn.Circuit(n, **kwargs)

    for r in range(depth):
        # single qubit gate layer
        single_qubit_layer(circ, gate_round=r)

        # alternate between forward and backward CZ layers
        two_qubit_layer(
            circ, gate2=gate2, gate_round=r, reverse=r % 2 == 0)

    # add a final single qubit layer
    single_qubit_layer(circ, gate_round=r + 1)

    return circ
The form of the 'U3' gate (which generalizes all possible single qubit gates) can be seen here - U_gate(). Now we are ready to instantiate a circuit:
n = 6
depth = 9
gate2 = 'CZ'
circ = ansatz_circuit(n, depth, gate2=gate2)
circ
<Circuit(n=6, num_gates=105, gate_opts={'contract': 'auto-split-gate', 'propagate_tags': 'register'})>
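The gate count reported above can be checked by hand: each of the depth rounds contributes n single qubit gates and n - 1 entangling gates, and the final layer adds another n single qubit gates. A small sketch of that arithmetic (the helper name count_gates is made up for illustration):

```python
def count_gates(n, depth):
    """Count gates in the layered ansatz: ``depth`` rounds of
    (n single-qubit + (n - 1) two-qubit) gates, plus one final
    single-qubit layer."""
    num_u3 = (depth + 1) * n
    num_cz = depth * (n - 1)
    return num_u3, num_cz, num_u3 + num_cz


num_u3, num_cz, total = count_gates(n=6, depth=9)
print(num_u3, num_cz, total)  # 60 45 105
```

With three angles per U3 gate, the 60 parametrized gates correspond to 180 real parameters for the optimizer.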
We can extract just the unitary part of the circuit as a tensor network like so:
V = circ.uni
/media/johnnie/Storage2TB/Sync/dev/python/quimb/quimb/tensor/circuit.py:1194: FutureWarning: In future the tensor network returned by ``circ.uni`` will not be transposed as it is currently, to match the expectation from ``U = circ.uni.to_dense()`` behaving like ``U @ psi``. You can retain this behaviour with ``circ.get_uni(transposed=True)``.
warnings.warn(
You can see it already has various tags simultaneously identifying different structures:
# types of gate
V.draw(color=['U3', gate2], show_inds=True)
# layers of gates
V.draw(color=[f'ROUND_{i}' for i in range(depth + 1)], show_inds=True)
# what register each tensor is 'above'
V.draw(color=[f'I{i}' for i in range(n)], show_inds=True)
# a unique tag per gate applied
V.draw(color=[f'GATE_{i}' for i in range(circ.num_gates)], legend=False)
10.2. The Target Unitary
Next we need a target unitary to try and digitally replicate. Here we’ll take an Ising Hamiltonian and a short time evolution. Once we have the dense (matrix) form of the target unitary \(U\) we need to convert it to a tensor which we can put in a tensor network:
# the hamiltonian
H = qu.ham_ising(n, jz=1.0, bx=0.7, cyclic=False)
# the propagator for the hamiltonian
t = 2
U_dense = qu.expm(-1j * t * H)
# 'tensorized' version of the unitary propagator
U = qtn.Tensor(
    data=U_dense.reshape([2] * (2 * n)),
    inds=[f'k{i}' for i in range(n)] + [f'b{i}' for i in range(n)],
    tags={'U_TARGET'}
)
U.draw(color=['U3', gate2, 'U_TARGET'])
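To make the reshape step concrete, here is a small NumPy-only sketch for two qubits. It assumes the usual row-major layout, in which the first n axes of the reshaped array index the rows of the matrix and the last n axes index its columns; the variable names are illustrative only.

```python
import numpy as np

n = 2
rng = np.random.default_rng(0)
# any 2**n x 2**n matrix will do for illustrating the index layout
A = rng.normal(size=(2**n, 2**n)) + 1j * rng.normal(size=(2**n, 2**n))

# one axis of size 2 per qubit: n 'row' axes followed by n 'column' axes
T = A.reshape([2] * (2 * n))

# matrix element <10| A |01>: row index 0b10 = 2, column index 0b01 = 1
elem_matrix = A[2, 1]
elem_tensor = T[1, 0, 0, 1]

# grouping the axes back together recovers the original matrix
A_back = T.reshape(2**n, 2**n)
```

Each axis of the reshaped array then plays the role of one named index on the tensor, which is why the list of index names has 2 * n entries.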
The core object describing how similar two unitaries are is: \(\mathrm{Tr}(V^{\dagger}U)\), which we can naturally visualize as a tensor network:
(V.H & U).draw(color=['U3', gate2, 'U_TARGET'])
For our loss function we’ll normalize the magnitude of this overlap by \(2^n\) and subtract it from one, so that the optimizer (which minimizes) drives the two unitaries together.
def loss(V, U):
    return 1 - abs((V.H & U).contract(all, optimize='auto-hq')) / 2**n
# check our current unitary 'infidelity':
loss(V, U)
0.9881326237593226
So as expected currently the two unitaries are not similar at all.
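For small systems the same quantity can be computed with dense matrices, which may help build intuition for the loss. The sketch below is a NumPy-only analogue (the name dense_loss is hypothetical); it shows that the loss vanishes when the two unitaries agree up to a global phase and reaches one when their trace overlap is zero.

```python
import numpy as np


def dense_loss(V, U):
    """1 - |Tr(V^dagger U)| / d for two d x d matrices."""
    d = U.shape[0]
    return 1 - abs(np.trace(V.conj().T @ U)) / d


# Pauli X and Z as simple 2 x 2 unitaries
X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)

same = dense_loss(X, X)                  # identical unitaries
phase = dense_loss(X, np.exp(0.3j) * X)  # differ by a global phase only
orthogonal = dense_loss(X, Z)            # Tr(X Z) = 0
```

Taking the absolute value of the trace is what makes the loss insensitive to an overall phase, which has no physical effect on the circuit.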
10.3. The Tensor Network Optimization
Now we are ready to construct the TNOptimizer object, with options detailed below:
# use the autograd/jax based optimizer
tnopt = qtn.TNOptimizer(
    V,                        # the tensor network we want to optimize
    loss,                     # the function we want to minimize
    loss_constants={'U': U},  # supply U to the loss function as a constant TN
    tags=['U3'],              # only optimize U3 tensors
    autodiff_backend='jax',   # use 'autograd' for non-compiled optimization
    optimizer='L-BFGS-B',     # the optimization algorithm
)
Note
If tags is not specified the default is to optimize all tensors. In that case, instead of specifying tensor tags to opt in, you can use constant_tags to opt out the tensors you don’t want to optimize, which may be more convenient.
We could call optimize for pure gradient based optimization, but since unitary circuits can be tricky we’ll use optimize_basinhopping which combines gradient descent with ‘hopping’ to escape local minima:
# allow 10 hops with 500 steps in each 'basin'
V_opt = tnopt.optimize_basinhopping(n=500, nhop=10)
+0.005301594734 [best: +0.005261898041] : 45%|████▌ | 2258/5000 [00:40<00:49, 55.26it/s]
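The general idea of combining local descent with random hops can be sketched on a one-dimensional toy function. This is a generic illustration of basin hopping using only the standard library, with a simple Metropolis acceptance rule; the function names, step sizes and temperature are assumptions for the example, not what the optimizer above does internally.

```python
import math
import random


def f(x):
    """A 1D test function with many local minima."""
    return 0.1 * x**2 + math.sin(3 * x)


def grad_f(x):
    return 0.2 * x + 3 * math.cos(3 * x)


def local_minimize(x, steps=500, lr=0.05):
    """Plain gradient descent, standing in for the 'basin' optimization."""
    for _ in range(steps):
        x -= lr * grad_f(x)
    return x


def basin_hopping(x0, nhop=10, step=2.0, temperature=1.0, seed=0):
    rng = random.Random(seed)
    x = local_minimize(x0)
    best_x = x
    for _ in range(nhop):
        # 'hop': random perturbation followed by a local minimization
        x_new = local_minimize(x + rng.uniform(-step, step))
        # Metropolis criterion: always accept downhill, sometimes uphill
        delta = f(x_new) - f(x)
        if delta < 0 or rng.random() < math.exp(-delta / temperature):
            x = x_new
        # keep track of the best local minimum seen so far
        if f(x) < f(best_x):
            best_x = x
    return best_x


x_start = 4.0
x_first = local_minimize(x_start)  # pure gradient descent result
x_best = basin_hopping(x_start)    # gradient descent plus hops
```

Since the best minimum found so far is always retained, the hopping result can never be worse than the first basin reached by plain descent.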
The optimized tensor network still contains PTensor instances but now with optimized parameters. For example, here’s the tensor of the U3 gate acting on qubit-2 in round-4:
V_opt['U3', 'I2', 'ROUND_4']
PTensor(shape=(2, 2), inds=('_8d3684AAABr', '_8d3684AAABh'), tags=oset(['GATE_46', 'ROUND_4', 'U3', 'I2']))
We can see the parameters have been updated by the training:
# the initial values
V['U3', 'ROUND_4', 'I2'].params
array([0.25927184, 0.10598236, 0.42593174])
# the optimized values
V_opt['U3', 'ROUND_4', 'I2'].params
array([ 0.52709323, 0.6036498 , -0.23340225], dtype=float32)
We can see what gate these parameters would generate:
qu.U_gate(*V_opt['U3', 'ROUND_4', 'I2'].params)
[[ 0.965472-0.j -0.253443+0.060252j]
[ 0.214467+0.147877j 0.90005 +0.349352j]]
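The matrix printed above is consistent with the widely used U3 parametrization in terms of three angles. The NumPy sketch below writes that convention out explicitly (the function name u3 is local to this example) and compares it against the printed values.

```python
import numpy as np


def u3(theta, phi, lamda):
    """Standard U3 parametrization of a single qubit gate."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([
        [c, -np.exp(1j * lamda) * s],
        [np.exp(1j * phi) * s, np.exp(1j * (phi + lamda)) * c],
    ])


# the optimized parameters quoted above
G = u3(0.52709323, 0.6036498, -0.23340225)

# values printed above, for comparison
G_printed = np.array([
    [0.965472 - 0.0j, -0.253443 + 0.060252j],
    [0.214467 + 0.147877j, 0.90005 + 0.349352j],
])

matches = np.allclose(G, G_printed, atol=1e-5)
is_unitary = np.allclose(G.conj().T @ G, np.eye(2))
```

Any choice of the three angles yields a unitary matrix in this form, so the optimizer is free to move the parameters without constraint.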
A final sanity check we can perform is to try evolving a random state with the target unitary and trained circuit and check the fidelity between the resulting states.
First we turn the tensor network version of \(V\) into a dense matrix:
V_opt_dense = V_opt.to_dense([f'k{i}' for i in range(n)], [f'b{i}' for i in range(n)])
Next we create a random initial state, and evolve it with both the target unitary and the trained circuit:
psi0 = qu.rand_ket(2**n)
# this is the exact state we want
psif_exact = U_dense @ psi0
# this is the state our circuit will produce if fed `psi0`
psif_apprx = V_opt_dense @ psi0
The (in)fidelity should broadly match our training loss:
f"Fidelity: {100 * qu.fidelity(psif_apprx, psif_exact):.2f} %"
'Fidelity: 99.57 %'
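For reference, the fidelity between two pure states is usually defined as the squared magnitude of their overlap. The sketch below assumes that definition and normalized state vectors; the helper name pure_state_fidelity is illustrative.

```python
import numpy as np


def pure_state_fidelity(a, b):
    """|<a|b>|^2 for two normalized state vectors."""
    # np.vdot conjugates its first argument
    return abs(np.vdot(a, b)) ** 2


zero = np.array([1, 0], dtype=complex)
one = np.array([0, 1], dtype=complex)
plus = (zero + one) / np.sqrt(2)

f_same = pure_state_fidelity(plus, 1j * plus)  # global phase is irrelevant
f_half = pure_state_fidelity(zero, plus)       # overlap of 1/sqrt(2)
f_orth = pure_state_fidelity(zero, one)        # orthogonal states
```

Like the training loss, this quantity ignores global phases, which is why the two measures can be expected to track each other.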
10.4. Extracting the New Circuit
We can extract the trained circuit parameters by updating the original Circuit object from the trained TN:
circ.update_params_from(V_opt)
# the basic gate specification
circ.gates
[<Gate(label=U3, params=(0.16929962, -0.6534094, -0.25308686), qubits=(0,), round=0, parametrize=True))>,
<Gate(label=U3, params=(-0.8037065, -0.25300127, 1.5709982), qubits=(1,), round=0, parametrize=True))>,
<Gate(label=U3, params=(-1.0775858, 0.29377267, 1.573218), qubits=(2,), round=0, parametrize=True))>,
<Gate(label=U3, params=(-0.70041966, -0.54078287, 1.5730082), qubits=(3,), round=0, parametrize=True))>,
<Gate(label=U3, params=(2.233548, -0.6997934, 1.5887524), qubits=(4,), round=0, parametrize=True))>,
<Gate(label=U3, params=(2.6693592, -0.3773012, 1.6181816), qubits=(5,), round=0, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=0)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=0)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=0)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=0)>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=0)>,
<Gate(label=U3, params=(0.9794985, -0.1482388, 0.3628862), qubits=(0,), round=1, parametrize=True))>,
<Gate(label=U3, params=(1.3343341, -0.8915439, 0.252669), qubits=(1,), round=1, parametrize=True))>,
<Gate(label=U3, params=(0.0016000561, 1.3208843, 1.8358507), qubits=(2,), round=1, parametrize=True))>,
<Gate(label=U3, params=(0.20370977, -0.5142169, 0.52689326), qubits=(3,), round=1, parametrize=True))>,
<Gate(label=U3, params=(1.7909783, 0.082318805, 0.6477436), qubits=(4,), round=1, parametrize=True))>,
<Gate(label=U3, params=(-1.6923836, 0.22169495, -0.06101191), qubits=(5,), round=1, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=1)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=1)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=1)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=1)>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=1)>,
<Gate(label=U3, params=(-1.5676986, 0.79075044, 1.9245409), qubits=(0,), round=2, parametrize=True))>,
<Gate(label=U3, params=(1.4227079, 0.50024813, 0.8913113), qubits=(1,), round=2, parametrize=True))>,
<Gate(label=U3, params=(-1.5313544, 0.5816354, -0.30510747), qubits=(2,), round=2, parametrize=True))>,
<Gate(label=U3, params=(0.05509888, 0.46960974, 0.5104008), qubits=(3,), round=2, parametrize=True))>,
<Gate(label=U3, params=(-1.4967402, 2.4865503, -0.09552598), qubits=(4,), round=2, parametrize=True))>,
<Gate(label=U3, params=(1.6648741, 1.5529978, 0.33650705), qubits=(5,), round=2, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=2)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=2)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=2)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=2)>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=2)>,
<Gate(label=U3, params=(2.3430626, 0.5120296, 0.76947355), qubits=(0,), round=3, parametrize=True))>,
<Gate(label=U3, params=(1.4037384, 0.34615904, 2.6413994), qubits=(1,), round=3, parametrize=True))>,
<Gate(label=U3, params=(-1.4830327, 0.22897828, -0.58199924), qubits=(2,), round=3, parametrize=True))>,
<Gate(label=U3, params=(0.002939354, 0.941026, 2.3286314), qubits=(3,), round=3, parametrize=True))>,
<Gate(label=U3, params=(-0.19347823, 1.4644276, 0.6541145), qubits=(4,), round=3, parametrize=True))>,
<Gate(label=U3, params=(0.79052246, 1.3951869, -0.1806278), qubits=(5,), round=3, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=3)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=3)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=3)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=3)>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=3)>,
<Gate(label=U3, params=(1.7128671, -0.83896357, 0.87925446), qubits=(0,), round=4, parametrize=True))>,
<Gate(label=U3, params=(1.7145393, -0.38567403, -0.34633058), qubits=(1,), round=4, parametrize=True))>,
<Gate(label=U3, params=(0.52709323, 0.6036498, -0.23340225), qubits=(2,), round=4, parametrize=True))>,
<Gate(label=U3, params=(0.005033329, -0.7817636, -0.5969079), qubits=(3,), round=4, parametrize=True))>,
<Gate(label=U3, params=(2.1565654, 0.751235, -1.485796), qubits=(4,), round=4, parametrize=True))>,
<Gate(label=U3, params=(1.5731558, 1.591097, 0.31102985), qubits=(5,), round=4, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=4)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=4)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=4)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=4)>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=4)>,
<Gate(label=U3, params=(-2.325529, 0.418274, -1.1409844), qubits=(0,), round=5, parametrize=True))>,
<Gate(label=U3, params=(0.011331396, -0.6176183, 0.28041682), qubits=(1,), round=5, parametrize=True))>,
<Gate(label=U3, params=(1.4015418, 0.5063352, -0.5949921), qubits=(2,), round=5, parametrize=True))>,
<Gate(label=U3, params=(0.009612298, 0.2217627, 0.04534868), qubits=(3,), round=5, parametrize=True))>,
<Gate(label=U3, params=(1.8638043, -0.20229527, -0.7953688), qubits=(4,), round=5, parametrize=True))>,
<Gate(label=U3, params=(-0.02179846, -0.57399744, 0.8407001), qubits=(5,), round=5, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=5)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=5)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=5)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=5)>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=5)>,
<Gate(label=U3, params=(1.3877774, 0.9551876, 0.9025401), qubits=(0,), round=6, parametrize=True))>,
<Gate(label=U3, params=(0.002978851, 2.6452973, 0.20050335), qubits=(1,), round=6, parametrize=True))>,
<Gate(label=U3, params=(0.44300193, -0.30685717, -0.5065869), qubits=(2,), round=6, parametrize=True))>,
<Gate(label=U3, params=(0.076909065, 1.6145856, 0.6035761), qubits=(3,), round=6, parametrize=True))>,
<Gate(label=U3, params=(-0.5196627, 1.6240075, 0.20132811), qubits=(4,), round=6, parametrize=True))>,
<Gate(label=U3, params=(2.2904618, 0.9574056, -0.4791421), qubits=(5,), round=6, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=6)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=6)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=6)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=6)>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=6)>,
<Gate(label=U3, params=(-0.30187988, -0.87076926, -0.8777541), qubits=(0,), round=7, parametrize=True))>,
<Gate(label=U3, params=(1.624832, 1.3061262, 1.019058), qubits=(1,), round=7, parametrize=True))>,
<Gate(label=U3, params=(-0.0049230885, 1.0577729, 1.851811), qubits=(2,), round=7, parametrize=True))>,
<Gate(label=U3, params=(0.02848778, 1.4706283, 1.2345053), qubits=(3,), round=7, parametrize=True))>,
<Gate(label=U3, params=(0.2926683, -0.228561, 0.6961813), qubits=(4,), round=7, parametrize=True))>,
<Gate(label=U3, params=(-0.64344996, -0.6371593, -1.0523038), qubits=(5,), round=7, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=7)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=7)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=7)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=7)>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=7)>,
<Gate(label=U3, params=(-1.3433273, 1.4197886, 0.5518878), qubits=(0,), round=8, parametrize=True))>,
<Gate(label=U3, params=(-1.7108984, -0.03164328, -1.3058407), qubits=(1,), round=8, parametrize=True))>,
<Gate(label=U3, params=(-1.4684929, -0.23041597, 0.5393066), qubits=(2,), round=8, parametrize=True))>,
<Gate(label=U3, params=(0.008398809, -0.18387741, -0.6591129), qubits=(3,), round=8, parametrize=True))>,
<Gate(label=U3, params=(-0.000680109, 0.2461698, -0.4091857), qubits=(4,), round=8, parametrize=True))>,
<Gate(label=U3, params=(0.110012054, 1.3755561, 1.2930975), qubits=(5,), round=8, parametrize=True))>,
<Gate(label=CZ, params=(), qubits=(4, 5), round=8)>,
<Gate(label=CZ, params=(), qubits=(3, 4), round=8)>,
<Gate(label=CZ, params=(), qubits=(2, 3), round=8)>,
<Gate(label=CZ, params=(), qubits=(1, 2), round=8)>,
<Gate(label=CZ, params=(), qubits=(0, 1), round=8)>,
<Gate(label=U3, params=(0.4610416, 0.26438755, 1.6497571), qubits=(0,), round=9, parametrize=True))>,
<Gate(label=U3, params=(-0.7864091, 1.571062, 0.03173134), qubits=(1,), round=9, parametrize=True))>,
<Gate(label=U3, params=(-1.0439473, 1.5646937, 0.2356552), qubits=(2,), round=9, parametrize=True))>,
<Gate(label=U3, params=(0.88693166, 1.5739466, -0.4111447), qubits=(3,), round=9, parametrize=True))>,
<Gate(label=U3, params=(0.4123549, 1.0429444, 1.6132663), qubits=(4,), round=9, parametrize=True))>,
<Gate(label=U3, params=(1.3669215, 3.024821, 0.9148983), qubits=(5,), round=9, parametrize=True))>]