{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "d20b0ee1",
"metadata": {
"raw_mimetype": "text/restructuredtext",
"tags": []
},
"source": [
"(tensor-network-drawing)=\n",
"\n",
"# Drawing\n",
"\n",
"`quimb` has a lot of functionality for drawing tensor networks that can be\n",
"useful for debugging, interactive development, and producing figures etc.\n",
"This page is a general overview of various options, mostly centered around the\n",
"method [`TensorNetwork.draw`](quimb.tensor.tensor_core.TensorNetwork.draw). Under the hood this\n",
"calls [networkx](https://networkx.org/documentation/stable//reference/drawing.html)\n",
"which itself uses [matplotlib](https://matplotlib.org/)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "ce709f65",
"metadata": {},
"outputs": [],
"source": [
"%config InlineBackend.figure_formats = ['svg']\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import quimb.tensor as qtn"
]
},
{
"cell_type": "markdown",
"id": "8b7907d7",
"metadata": {},
"source": [
"We'll use a 3D grid tensor network as our basic example."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d633f126",
"metadata": {},
"outputs": [],
"source": [
"Lx = Ly = Lz = 4\n",
"D = 2\n",
"tn = qtn.TN3D_rand(Lx, Ly, Lz, D=D)"
]
},
{
"cell_type": "markdown",
"id": "c15d54ee",
"metadata": {},
"source": [
"By default, bonds are drawn with a width proportional to `log2` of their dimension, whereas nodes have a fixed size."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "71116d75",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tn.draw()"
]
},
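{
"cell_type": "markdown",
"id": "sketch-log2-widths-md",
"metadata": {},
"source": [
"To see what the `log2` scaling means in practice, here is a small illustrative sketch in plain Python. The helper `sketch_bond_width` is hypothetical and not part of `quimb`; it assumes a width of exactly `edge_scale * log2(D)`, whereas `quimb`'s actual formula may include extra offsets or limits. What it shows is that doubling the bond dimension adds a constant amount of width rather than doubling it, which keeps large bonds from swamping the drawing."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-log2-widths",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def sketch_bond_width(bond_dim, edge_scale=1.0):\n",
"    \"\"\"Hypothetical helper: a line width proportional to log2(bond_dim).\"\"\"\n",
"    return edge_scale * math.log2(bond_dim)\n",
"\n",
"\n",
"# each doubling of the bond dimension adds `edge_scale` to the width\n",
"widths = {d: sketch_bond_width(d) for d in (2, 4, 8, 16, 32)}\n",
"widths"
]
},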
{
"cell_type": "markdown",
"id": "f25bcbd0",
"metadata": {},
"source": [
"By default, index names are not shown and tensor tags are only shown for small tensor networks. Both can be controlled manually like so:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "edc4f2f4",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"qtn.PEPS.rand(3, 3, D).draw(show_tags=True, show_inds=True)"
]
},
{
"cell_type": "markdown",
"id": "a05cd152",
"metadata": {},
"source": [
"To see the inner index names (bonds) as well as the outer index names, use `show_inds='all'`:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2863598d",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"qtn.PEPS.rand(3, 3, D).draw(show_tags=False, show_inds=\"all\")"
]
},
{
"cell_type": "markdown",
"id": "8bf25f35",
"metadata": {},
"source": [
"## Coloring\n",
"\n",
"The first argument to `draw` is `color=`, which can either be a single tag or a sequence of tags:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f4a9757c",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# add the same tag to every tensor\n",
"tn.add_tag(\"CUBE\")\n",
"\n",
"# color that tag and each corner of our TN\n",
"color = [\"CUBE\"] + [\n",
" f\"I{i},{j},{k}\"\n",
" for i in (0, Lx - 1)\n",
" for j in (0, Ly - 1)\n",
" for k in (0, Lz - 1)\n",
"]\n",
"\n",
"tn.draw(color)"
]
},
{
"cell_type": "markdown",
"id": "1e16bb66",
"metadata": {},
"source": [
"If you have many tags, or are only interested in the colors themselves, you can supply `legend=False` to turn off the legend.\n",
"`quimb` can also show tensors that have multiple matching tags; the style is controlled by the kwarg `multi_tag_style`, which should be one of\n",
"`{\"auto\", \"pie\", \"nest\", \"average\", \"last\"}`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e07570ed-0826-4e3e-89ce-36b57df02f88",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"t = qtn.rand_tensor([2, 3, 4, 5], \"abcd\", [\"W\", \"X\", \"Y\", \"Z\"])\n",
"\n",
"fig, axs = plt.subplots(1, 4)\n",
"\n",
"for i, multi_tag_style in enumerate([\"pie\", \"nest\", \"average\", \"last\"]):\n",
" t.draw(\n",
" [\"W\", \"X\", \"Y\", \"Z\"],\n",
" multi_tag_style=multi_tag_style,\n",
" ax=axs[i],\n",
" node_scale=2,\n",
" title=multi_tag_style,\n",
" legend=False,\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "d84b9155",
"metadata": {
"raw_mimetype": "text/restructuredtext",
"tags": []
},
"source": [
":::{hint}\n",
"`quimb` tries to produce a sequence of colors that are reasonably locally distinguishable\n",
"but also have some global ordering when using many colors. These are based on the palette\n",
"designed with color blindness in mind by [Okabe & Ito](https://jfly.uni-koeln.de/color/).\n",
"You can supply custom colors with the `custom_colors=` kwarg.\n",
":::"
]
},
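{
"cell_type": "markdown",
"id": "sketch-custom-colors-md",
"metadata": {},
"source": [
"As a sketch of `custom_colors=`, the cell below passes an explicit list of hex colors, one per tag given as the `color` argument. The hex values are the commonly quoted Okabe & Ito colors (black omitted). Two things here are assumptions rather than documented behavior: that `quimb` accepts any matplotlib-compatible color specification (including hex strings), and that the custom colors are matched to the tags in order. The names `okabe_ito` and `t_colors` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-custom-colors",
"metadata": {},
"outputs": [],
"source": [
"import quimb.tensor as qtn\n",
"\n",
"# commonly quoted Okabe & Ito colors, black omitted\n",
"okabe_ito = [\"#E69F00\", \"#56B4E9\", \"#009E73\", \"#F0E442\", \"#0072B2\", \"#D55E00\", \"#CC79A7\"]\n",
"\n",
"tags = [\"W\", \"X\", \"Y\", \"Z\"]\n",
"t_colors = qtn.rand_tensor([2, 3, 4, 5], \"abcd\", tags)\n",
"\n",
"# assumption: one custom color per tag, matched in order\n",
"t_colors.draw(tags, custom_colors=okabe_ito[: len(tags)], multi_tag_style=\"pie\", node_scale=2)"
]
},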
{
"cell_type": "markdown",
"id": "40964e1e",
"metadata": {},
"source": [
"## Highlighting indices\n",
"\n",
"You can visualize a subset of indices by supplying a sequence of them to the `highlight_inds=` kwarg like so:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8eca3d53",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# get a central tensor and its indices\n",
"tag = f\"I{Lx // 2},{Ly // 2},{Lz // 2}\"\n",
"t = tn[tag]\n",
"inds = t.inds\n",
"tn.draw(color=tag, highlight_inds=inds, edge_scale=2)"
]
},
{
"cell_type": "markdown",
"id": "6a70c4fb",
"metadata": {},
"source": [
"The color can be controlled with ``highlight_inds_color``."
]
},
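{
"cell_type": "markdown",
"id": "sketch-highlight-inds-color-md",
"metadata": {},
"source": [
"For example, the cell below builds a smaller 3D network and highlights the indices of its central tensor in a custom color. The RGBA tuple mirrors the form used for `highlight_tids_color` in the next section; the `3 x 3 x 3` size and the names `tn_small`, `center` and `center_inds` are just illustrative choices."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-highlight-inds-color",
"metadata": {},
"outputs": [],
"source": [
"import quimb.tensor as qtn\n",
"\n",
"# a smaller 3D random network, whose central site tag is \"I1,1,1\"\n",
"tn_small = qtn.TN3D_rand(3, 3, 3, D=2)\n",
"center = \"I1,1,1\"\n",
"center_inds = tn_small[center].inds\n",
"\n",
"tn_small.draw(\n",
"    color=center,\n",
"    highlight_inds=center_inds,\n",
"    highlight_inds_color=(0.0, 0.6, 1.0, 0.8),\n",
"    edge_scale=2,\n",
")"
]
},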
{
"cell_type": "markdown",
"id": "8bffed67",
"metadata": {},
"source": [
"## Highlighting `tids`\n",
"\n",
"While tensors can carry arbitrary tags and can usually be identified by these, it is sometimes useful to be able to highlight tensors based on their underlying `tids`, each of which is a unique integer representing a node in the hypergraph."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e7576d07",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# get the first plane of tensor tids\n",
"tids = list(tn.tensor_map.keys())[: Lx * Ly]\n",
"tids"
]
},
{
"cell_type": "markdown",
"id": "1728c43d",
"metadata": {},
"source": [
"These can then be supplied to `highlight_tids=`, with the color controlled by `highlight_tids_color`:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3aae6ec1",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tn.draw(highlight_tids=tids, highlight_tids_color=(1.0, 0.0, 0.5, 0.5))"
]
},
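{
"cell_type": "markdown",
"id": "sketch-tids-by-tag-md",
"metadata": {},
"source": [
"Taking the first `Lx * Ly` entries of `tensor_map` relies on the order in which the tensors were added to the network. A sketch of a more explicit alternative is to filter `tensor_map`, which maps each tid to its tensor, by inspecting each tensor's tags. It assumes the `I{i},{j},{k}` site-tag format used above; `tn_demo` and `plane_tids` are illustrative names."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-tids-by-tag",
"metadata": {},
"outputs": [],
"source": [
"import quimb.tensor as qtn\n",
"\n",
"tn_demo = qtn.TN3D_rand(4, 4, 4, D=2)\n",
"\n",
"# keep the tid of every tensor carrying a site tag of the form \"I0,{j},{k}\"\n",
"plane_tids = [\n",
"    tid\n",
"    for tid, t in tn_demo.tensor_map.items()\n",
"    if any(tag.startswith(\"I0,\") for tag in t.tags)\n",
"]\n",
"\n",
"tn_demo.draw(highlight_tids=plane_tids, highlight_tids_color=(1.0, 0.0, 0.5, 0.5))"
]
},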
{
"cell_type": "markdown",
"id": "47cb51d2",
"metadata": {},
"source": [
"## Auto coloring edges\n",
"\n",
"You can automatically color the edges of a tensor network by supplying\n",
"`edge_color=True`; the colors will match those of the fancy HTML repr:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "74812402",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TensorNetwork2D(tensors=9, indices=12, Lx=3, Ly=3, max_bond=4)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tn2d = qtn.TN2D_rand(3, 3, 4)\n",
"tn2d"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "dd8dadfd",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"tn2d.draw(edge_color=True, show_inds=\"all\")"
]
},
{
"cell_type": "markdown",
"id": "80d8a40f-9041-4c4d-9f14-d009f019cc5e",
"metadata": {},
"source": [
"## Specifying output indices\n",
"\n",
"The output indices of a tensor network are assumed to be those that appear on only a single tensor. For hyper tensor networks, where this is not the case, you can specify `output_inds` explicitly to ensure that they are rendered as dangling."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "5cfaf56d-c04d-4eb0-8242-8e071db0292c",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"Lx = Ly = 3\n",
"htn = qtn.HTN2D_classical_ising_partition_function(Lx, Ly, beta=1.0)\n",
"output_inds = [f\"s{i},{j}\" for i in range(Lx) for j in range(Ly)]\n",
"htn.draw(output_inds=output_inds, edge_color=True)"
]
},
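{
"cell_type": "markdown",
"id": "sketch-output-inds-md",
"metadata": {},
"source": [
"To see why `output_inds` is needed here, you can count how many tensors each spin index appears on using `ind_map`, which maps each index name to the tids of the tensors it belongs to. Every spin is shared by several tensors, so none of them would be picked up as dangling automatically. The names `htn_demo`, `spins` and `sharing` below are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-output-inds",
"metadata": {},
"outputs": [],
"source": [
"import quimb.tensor as qtn\n",
"\n",
"n = 3\n",
"htn_demo = qtn.HTN2D_classical_ising_partition_function(n, n, beta=1.0)\n",
"spins = [f\"s{i},{j}\" for i in range(n) for j in range(n)]\n",
"\n",
"# number of tensors each spin index appears on\n",
"sharing = {ix: len(htn_demo.ind_map[ix]) for ix in spins}\n",
"sharing"
]
},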
{
"cell_type": "markdown",
"id": "ccac79f9",
"metadata": {},
"source": [
"## Positioning tensors\n",
"\n",
"### Automatic layouts\n",
"\n",
"The automatic layout strategy `quimb` adopts (`layout=\"auto\"`) is to first lay the tensors out using a relatively efficient scheme, and then 'relax' the positions into something usually more natural using a (slower) force repulsion algorithm.\n",
"\n",
"Relevant options are:\n",
"- `layout`: if `\"auto\"` use the options below, else specify a layout directly (with no relaxation, i.e. set `iterations=0`).\n",
"- `initial_layout`: if using relaxation, the starting layout (the default for which is `'kamada_kawai'` for small graphs and `'spectral'` for large graphs).\n",
"- `iterations`: controls the number of force repulsion steps."
]
},
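{
"cell_type": "markdown",
"id": "sketch-layout-options-md",
"metadata": {},
"source": [
"A sketch comparing these options side by side on a small 2D network: the default automatic layout, the raw `'spectral'` starting layout with relaxation switched off via `iterations=0`, and the same starting layout followed by the default relaxation. The choice of network (`tn_layout`), the panel titles and the figure size are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-layout-options",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import quimb.tensor as qtn\n",
"\n",
"tn_layout = qtn.TN2D_rand(4, 4, 2)\n",
"\n",
"fig, axs = plt.subplots(1, 3, figsize=(9, 3))\n",
"\n",
"# default: an initial layout followed by force repulsion relaxation\n",
"tn_layout.draw(ax=axs[0], title=\"auto\")\n",
"\n",
"# the initial layout alone, with no relaxation steps\n",
"tn_layout.draw(initial_layout=\"spectral\", iterations=0, ax=axs[1], title=\"spectral, iterations=0\")\n",
"\n",
"# the same initial layout, then relaxed\n",
"tn_layout.draw(initial_layout=\"spectral\", ax=axs[2], title=\"spectral, relaxed\")"
]
},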
{
"attachments": {},
"cell_type": "markdown",
"id": "8437cfd1-6fb2-498b-afac-76c4715fc7b6",
"metadata": {},
"source": [
"Another decent `networkx` choice for the initial layout, which you might try if `'kamada_kawai'` isn't producing good results, is `'spectral'`. You should also be able to specify most of the [networkx layout algorithms](https://networkx.org/documentation/stable//reference/drawing.html#module-networkx.drawing.layout).\n",
"\n",
"If you have [`pygraphviz`](https://pygraphviz.github.io/) installed then you can use the layouts `\"neato\"`, `\"sfdp\"` and `\"dot\"`."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "aad98009-c654-4898-ba22-daa9077a8a9e",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"