{ "cells": [ { "cell_type": "markdown", "id": "522c9747-df9f-4c37-924e-4192bad04e32", "metadata": {}, "source": [ "# Generic Tensor Fitting\n", "\n", "`quimb` supports fitting arbitrary tensor networks to other tensors or tensor networks.\n", "Here we show how to decompose a 4-tensor into a ring of four tensors." ] }, { "cell_type": "code", "execution_count": 1, "id": "56573127-196d-4ae7-8b6d-8486df737416", "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "import numpy as np\n", "import quimb.tensor as qtn" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c088716c", "metadata": {}, "source": [ "Create a target 10x10x10x10 tensor with entries drawn uniformly at random between 0 and 1:" ] }, { "cell_type": "code", "execution_count": 2, "id": "aaf7d353", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Tensor(shape=(10, 10, 10, 10), inds=[a, b, c, d], tags={}),backend=numpy, dtype=float64, data=...
" ], "text/plain": [ "Tensor(shape=(10, 10, 10, 10), inds=('a', 'b', 'c', 'd'), tags=oset([]))" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t_target = qtn.Tensor(\n", " data=np.random.uniform(size=(10, 10, 10, 10)),\n", " inds=('a', 'b', 'c', 'd'),\n", ")\n", "t_target" ] }, { "cell_type": "code", "execution_count": 3, "id": "4c6b020d", "metadata": {}, "outputs": [], "source": [ "# normalize to unit norm, so the fit distance is easier to interpret\n", "t_target /= t_target.norm()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8894bcff", "metadata": {}, "source": [ "Here the target is a single dense tensor, but it could also be an arbitrary tensor network.\n", "\n", "Now we manually create the decomposed geometry, i.e. a ring of 4 tensors, each carrying one of the target's indices plus two bond indices of size `rank`." ] }, { "cell_type": "code", "execution_count": 4, "id": "d9353a5c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
TensorNetwork(tensors=4, indices=8)
Tensor(shape=(10, 5, 5), inds=[a, left, up], tags={}),backend=numpy, dtype=float64, data=...
Tensor(shape=(10, 5, 5), inds=[b, up, right], tags={}),backend=numpy, dtype=float64, data=...
Tensor(shape=(10, 5, 5), inds=[c, right, bottom], tags={}),backend=numpy, dtype=float64, data=...
Tensor(shape=(10, 5, 5), inds=[d, bottom, left], tags={}),backend=numpy, dtype=float64, data=...
" ], "text/plain": [ "TensorNetwork(tensors=4, indices=8)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rank = 5  # bond dimension between neighboring ring tensors\n", "\n", "# one outer index per tensor; shared index names ('up', 'right', ...) form the ring bonds\n", "tn_guess = qtn.TensorNetwork([\n", " qtn.Tensor(np.random.normal(size=(10, rank, rank)), inds=('a', 'left', 'up')),\n", " qtn.Tensor(np.random.normal(size=(10, rank, rank)), inds=('b', 'up', 'right')),\n", " qtn.Tensor(np.random.normal(size=(10, rank, rank)), inds=('c', 'right', 'bottom')),\n", " qtn.Tensor(np.random.normal(size=(10, rank, rank)), inds=('d', 'bottom', 'left')),\n", "])\n", "tn_guess" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d83ca2a2", "metadata": {}, "source": [ "The internal structure could be anything, as long as the outer indices match those of the target (and\n", "the contraction is possible)." ] }, { "cell_type": "code", "execution_count": 5, "id": "c96cafb3", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "tn_guess.draw(show_inds='all', highlight_inds=['a', 'b', 'c', 'd'])" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f0a34b16", "metadata": {}, "source": [ "Compute the initial distance to the target (the Frobenius norm of their difference):" ] }, { "cell_type": "code", "execution_count": 6, "id": "068ddd36", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.float64(2450.808525296577)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tn_guess.distance(t_target)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "cde20331", "metadata": {}, "source": [ "Perform an initial fit using ALS (alternating least squares); see the\n", "function [`TensorNetwork.fit`](TensorNetwork.fit) for more details:" ] }, { "cell_type": "code", "execution_count": 7, "id": "9269e5e8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "0.4697: 100%|██████████| 1000/1000 [00:01<00:00, 565.85it/s]\n" ] } ], "source": [ "tn_fitted = tn_guess.fit(t_target, method='als', steps=1000, progbar=True)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "49e80711", "metadata": {}, "source": [ "Sometimes autodiff-based optimization can do better than ALS; see\n", "[`TNOptimizer`](TNOptimizer) for more details. Here we refine the ALS result in place with `fit_`:" ] }, { "cell_type": "code", "execution_count": 8, "id": "66d5011f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "+0.457260587804 [best: +0.457260587804] : 21%|██▏ | 214/1000 [00:00<00:01, 546.10it/s]\n" ] }, { "data": { "text/html": [ "
TensorNetwork(tensors=4, indices=8)
Tensor(shape=(10, 5, 5), inds=[a, left, up], tags={}),backend=numpy, dtype=float64, data=...
Tensor(shape=(10, 5, 5), inds=[b, up, right], tags={}),backend=numpy, dtype=float64, data=...
Tensor(shape=(10, 5, 5), inds=[c, right, bottom], tags={}),backend=numpy, dtype=float64, data=...
Tensor(shape=(10, 5, 5), inds=[d, bottom, left], tags={}),backend=numpy, dtype=float64, data=...
" ], "text/plain": [ "TensorNetwork(tensors=4, indices=8)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tn_fitted.fit_(t_target, method='autodiff', steps=1000, progbar=True)" ] }, { "cell_type": "markdown", "id": "f25b486e", "metadata": {}, "source": [ "Double-check how close the fitted tensor network now is to the target:" ] }, { "cell_type": "code", "execution_count": 9, "id": "776da76a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.float64(0.45725182227607436)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tn_fitted.distance(t_target)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b1a82a5d", "metadata": {}, "source": [ "Treating the target as a wavefunction, the overlap of our fitted network with it is:" ] }, { "cell_type": "code", "execution_count": 10, "id": "356e54fb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.895422046569203" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tn_fitted @ t_target.H" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e2afad2d", "metadata": {}, "source": [ "Note that random tensors are generally hard to fit: the ring here has only 4 × 10 × 5 × 5 = 1,000 parameters,\n", "versus 10,000 entries in the target, so the resulting fidelity is only modest." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3" } }, "nbformat": 4, "nbformat_minor": 5 }