{
"cells": [
{
"cell_type": "markdown",
"id": "522c9747-df9f-4c37-924e-4192bad04e32",
"metadata": {},
"source": [
"# Generic Tensor Fitting\n",
"\n",
"`quimb` supports fitting arbitrary tensor networks to other tensors or tensor networks.\n",
"Here we show how to decompose a 4-index tensor into a ring of four tensors."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "56573127-196d-4ae7-8b6d-8486df737416",
"metadata": {},
"outputs": [],
"source": [
"%config InlineBackend.figure_formats = ['svg']\n",
"import numpy as np\n",
"import quimb.tensor as qtn"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c088716c",
"metadata": {},
"source": [
"Create a random 10 x 10 x 10 x 10 target tensor with entries drawn uniformly from $[0, 1)$:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "aaf7d353",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Tensor(shape=(10, 10, 10, 10), inds=('a', 'b', 'c', 'd'), tags=oset([]))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t_target = qtn.Tensor(\n",
" data=np.random.uniform(size=(10, 10, 10, 10)),\n",
" inds=('a', 'b', 'c', 'd'),\n",
")\n",
"t_target"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4c6b020d",
"metadata": {},
"outputs": [],
"source": [
"# normalize to unit norm, so the fitting distance is on a relative scale\n",
"t_target /= t_target.norm()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8894bcff",
"metadata": {},
"source": [
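"The target could also be an arbitrary tensor network.\n",
"\n",
"Now we manually create the geometry to fit: a ring of four tensors with random initial entries.\n",
"Each tensor carries one outer index (`a`, `b`, `c` or `d`) and two bonds of dimension `rank`\n",
"shared with its neighbors."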
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d9353a5c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorNetwork(tensors=4, indices=8)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rank = 5\n",
"\n",
"tn_guess = qtn.TensorNetwork([\n",
" qtn.Tensor(np.random.normal(size=(10, rank, rank)), inds=('a', 'left', 'up')),\n",
" qtn.Tensor(np.random.normal(size=(10, rank, rank)), inds=('b', 'up', 'right')),\n",
" qtn.Tensor(np.random.normal(size=(10, rank, rank)), inds=('c', 'right', 'bottom')),\n",
" qtn.Tensor(np.random.normal(size=(10, rank, rank)), inds=('d', 'bottom', 'left')),\n",
"])\n",
"tn_guess"
]
},
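{
"cell_type": "markdown",
"id": "ring-contraction-math",
"metadata": {},
"source": [
"Written out, the ring represents the contraction below. The symbols are chosen here for illustration only: $A, B, C, D$ are the four tensors in the order created above, and $i, j, k, m$ label the `left`, `up`, `right` and `bottom` bonds, each of dimension `rank`:\n",
"\n",
"$$\n",
"\\tilde{T}_{abcd} = \\sum_{i,j,k,m} A_{aij}\\, B_{bjk}\\, C_{ckm}\\, D_{dmi}\n",
"$$\n",
"\n",
"With `rank = 5` the ring has $4 \\times 10 \\times 5^2 = 1000$ parameters, compared with the $10^4$ entries of the target. In general, therefore, the fit can only be approximate."
]
},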
{
"attachments": {},
"cell_type": "markdown",
"id": "d83ca2a2",
"metadata": {},
"source": [
"Any internal structure would do, as long as the outer indices match those of the target\n",
"(and the contraction is possible)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c96cafb3",
"metadata": {},
"outputs": [],
"source": [
"tn_guess.draw(show_inds='all', highlight_inds=['a', 'b', 'c', 'd'])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f0a34b16",
"metadata": {},
"source": [
"Compute the initial distance (in terms of the Frobenius norm):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "068ddd36",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(2450.808525296577)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tn_guess.distance(t_target)"
]
},
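{
"cell_type": "markdown",
"id": "frobenius-distance-math",
"metadata": {},
"source": [
"For reference, write $T$ for the target and $\\tilde{T}$ for the contracted network (notation introduced here). For real tensors, the Frobenius distance is\n",
"\n",
"$$\n",
"\\| T - \\tilde{T} \\|_F = \\sqrt{\\sum_{abcd} \\left( T_{abcd} - \\tilde{T}_{abcd} \\right)^2} .\n",
"$$\n",
"\n",
"The target was normalized, so $\\|T\\|_F = 1$, and the reverse triangle inequality gives $\\big| \\|T - \\tilde{T}\\|_F - \\|\\tilde{T}\\|_F \\big| \\leq 1$. The large initial value is therefore within 1 of the norm of the random, unnormalized guess."
]
},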
{
"attachments": {},
"cell_type": "markdown",
"id": "cde20331",
"metadata": {},
"source": [
"Perform the initial fit using ALS (alternating least squares); see the\n",
"function [`TensorNetwork.fit`](TensorNetwork.fit) for more details:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9269e5e8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0.4697: 100%|██████████| 1000/1000 [00:01<00:00, 565.85it/s]\n"
]
}
],
"source": [
"tn_fitted = tn_guess.fit(t_target, method='als', steps=1000, progbar=True)"
]
},
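{
"cell_type": "markdown",
"id": "als-update-math",
"metadata": {},
"source": [
"The following sketches a single ALS update using the textbook scheme, written for real data. The details of `fit(method='als')` may differ. Label the ring tensors $A_{aij}, B_{bjk}, C_{ckm}, D_{dmi}$, where $i, j, k, m$ are the `left`, `up`, `right` and `bottom` bonds (labels chosen for illustration). With $B, C, D$ held fixed, the network is linear in $A$. Define the fixed environment\n",
"\n",
"$$\n",
"E_{(ij),(bcd)} = \\sum_{k,m} B_{bjk}\\, C_{ckm}\\, D_{dmi} .\n",
"$$\n",
"\n",
"If $A$ is reshaped into a $10 \\times 25$ matrix and the target $T$ into a $10 \\times 1000$ matrix, the network equals $A E$. Updating $A$ is then a linear least squares problem with a closed-form solution:\n",
"\n",
"$$\n",
"A \\leftarrow \\arg\\min_{A} \\| T - A E \\|_F = T E^{+} ,\n",
"$$\n",
"\n",
"where $E^{+}$ is the Moore-Penrose pseudo-inverse. One sweep applies the analogous update to each of the four tensors in turn. Sweeps are repeated for a fixed number of steps or until the distance stops improving."
]
},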
{
"attachments": {},
"cell_type": "markdown",
"id": "49e80711",
"metadata": {},
"source": [
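"Sometimes autodiff-based optimization can do better than ALS; see\n",
"[`TNOptimizer`](TNOptimizer) for more details. Here we refine the ALS result in place\n",
"with `fit_` (a trailing underscore marks in-place methods in `quimb`):"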
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "66d5011f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"+0.457260587804 [best: +0.457260587804] : 21%|██▏ | 214/1000 [00:00<00:01, 546.10it/s]\n"
]
},
{
"data": {
"text/plain": [
"TensorNetwork(tensors=4, indices=8)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tn_fitted.fit_(t_target, method='autodiff', steps=1000, progbar=True)"
]
},
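{
"cell_type": "markdown",
"id": "autodiff-gradient-math",
"metadata": {},
"source": [
"Unlike ALS, a gradient-based optimizer updates all tensors at once by following the derivative of a scalar loss. The following is an illustration only; the loss and optimizer actually used by `fit(method='autodiff')` are not shown here. Take real data and the squared distance $f = \\sum_{abcd} (T_{abcd} - \\tilde{T}_{abcd})^2$ between the target $T$ and the contracted network $\\tilde{T}_{abcd} = \\sum_{i,j} A_{aij} E_{bcd,ij}$. Here $A$ is one tensor of the ring and $E$ is the contraction of the other three. The gradient with respect to $A$ is\n",
"\n",
"$$\n",
"\\frac{\\partial f}{\\partial A_{aij}} = 2 \\sum_{bcd} \\left( \\tilde{T}_{abcd} - T_{abcd} \\right) E_{bcd,ij} .\n",
"$$\n",
"\n",
"Automatic differentiation produces the corresponding expressions for every tensor without deriving them by hand."
]
},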
{
"cell_type": "markdown",
"id": "f25b486e",
"metadata": {},
"source": [
"Double-check the distance between the refined network and the target:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "776da76a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"np.float64(0.45725182227607436)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tn_fitted.distance(t_target)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b1a82a5d",
"metadata": {},
"source": [
"Considering the target as a wavefunction, our fitted network has an overlap of:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "356e54fb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.895422046569203"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tn_fitted @ t_target.H"
]
},
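{
"cell_type": "markdown",
"id": "overlap-distance-math",
"metadata": {},
"source": [
"For a unit-norm target $T$ and real data, the distance and the overlap $F = \\langle T, \\tilde{T} \\rangle$ with the fitted network $\\tilde{T}$ are related by\n",
"\n",
"$$\n",
"\\| T - \\tilde{T} \\|_F^2 = 1 + \\|\\tilde{T}\\|_F^2 - 2F ,\n",
"$$\n",
"\n",
"which reduces to $2(1 - F)$ when $\\tilde{T}$ is also normalized. The values printed above are consistent with a nearly normalized fit: $2(1 - 0.8954) \\approx 0.209 \\approx 0.4573^2$."
]
},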
{
"attachments": {},
"cell_type": "markdown",
"id": "e2afad2d",
"metadata": {},
"source": [
"Note that random tensors are generally hard to fit with a low-rank network like this one,\n",
"so the fidelity achieved here is only moderate."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}