{ "cells": [ { "cell_type": "markdown", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "(ex-tensorflow-tn-opt)=\n", "\n", "# Optimizing a Tensor Network using Tensorflow\n", "\n", "In this example we show how a general machine learning\n", "strategy can be used to optimize arbitrary tensor networks\n", "with respect to any target loss function.\n", "\n", "We'll take the example of maximizing the overlap of some\n", "matrix product state with periodic boundary conditions\n", "with a densely represented state, since this does not\n", "have a simple, deterministic alternative.\n", "\n", "`quimb` makes use of `cotengra` which can contract\n", "tensors with a variety of backends as well as `autoray`\n", "for handling array operations agnostically. Here we'll use\n", "`tensorflow-v2` for the actual auto-gradient computation." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "\n", "import quimb as qu\n", "import quimb.tensor as qtn\n", "from quimb.tensor.optimize import TNOptimizer\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, find a (dense) PBC groundstate, $| gs \\rangle$:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "L = 16\n", "H = qu.ham_heis(L, sparse=True, cyclic=True)\n", "gs = qu.groundstate(H)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we convert it to a dense 1D 'tensor network':" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dense1D([\n", " Tensor(shape=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2), inds=('k0', 'k1', 'k2', 'k3', 'k4', 'k5', 'k6', 'k7', 'k8', 'k9', 'k10', 'k11', 'k12', 'k13', 'k14', 'k15'), tags=oset(['I0', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11', 'I12', 'I13', 'I14', 'I15'])),\n", "], tensors=1, indices=16, L=16, 
max_bond=2)\n" ] } ], "source": [ "# this converts the dense vector to an effective 1D tensor network (with only one tensor)\n", "target = qtn.Dense1D(gs)\n", "print(target)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we create an initial guess random MPS, $|\\psi\\rangle$:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ " 2023-10-13T15:15:47.911909\n", " image/svg+xml\n", " Matplotlib v3.7.2, https://matplotlib.org/\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "bond_dim = 32\n", "mps = qtn.MPS_rand_state(L, bond_dim, cyclic=True)\n", "mps.draw()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now need to set up the function that 'prepares' our tensor network.\n", "In the current example this involves making sure the state is always normalized." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def normalize_state(psi):\n", "    # divide by the norm sqrt(<psi|psi>) so the state stays normalized\n", "    return psi / (psi.H @ psi) ** 0.5\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we need to set up our 'loss' function, the function that returns\n", "the scalar quantity we want to minimize." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def negative_overlap(psi, target):\n", "    return -(psi.H @ target) ** 2  # minus so as to minimize\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can set up the tensor network optimizer object:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-13 15:16:10.049837: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2023-10-13 15:16:10.851994: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", "2023-10-13 15:16:11.991099: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355\n", "2023-10-13 15:16:12.047657: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n", "Skipping registering GPU devices...\n" ] } ], "source": [ "optmzr = TNOptimizer(\n", "    mps,  # our initial input, the tensors of which to optimize\n", "    loss_fn=negative_overlap,\n", "    norm_fn=normalize_state,\n", "    loss_constants={'target': target},  # this is a constant TN to supply to loss_fn\n", "    autodiff_backend='tensorflow',  # {'jax', 'tensorflow', 'autograd', 'torch'}\n", "    optimizer='L-BFGS-B',  # supplied to scipy.optimize.minimize\n", ")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we are ready to optimize our tensor network! Note that we supplied the constant tensor network ``target`` via ``loss_constants``, so its tensors will not be changed." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "-0.999723964428 [best: -0.999723964428] : : 109it [00:19, 5.48it/s] \n" ] } ], "source": [ "mps_opt = optmzr.optimize(100)  # perform ~100 gradient-based optimization steps\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The optimized (and normalized) output tensor network has already been converted back to numpy:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'numpy'" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mps_opt[0].backend\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also explicitly check that the returned state reproduces the overlap (the negated loss) shown above:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9997239644280398" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "((mps_opt.H & target) ^ all) ** 2\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Other things to try:\n", "\n", "- try other scipy optimizers for the `optimizer=` option\n", "- try other autodiff backends for the `autodiff_backend=` option\n", "  * ``'jax'`` - likely the best performance but slow to compile the initial computation\n", "  * ``'autograd'`` - numpy based, CPU-only optimization\n", "  * ``'torch'`` - (PyTorch), quick compilation and decent performance, though no complex support (yet?)\n", "- use single precision data for better GPU acceleration\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also keep optimizing:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/100 [00:00