{ "cells": [ { "cell_type": "markdown", "id": "7317794f-440d-4a64-9771-3dc4159a48b3", "metadata": {}, "source": [ "(ex_tracing_tn_functions)=\n", "\n", "# Tracing tensor network functions and reusing intermediates\n", "\n", "This example shows how to use {mod}`autoray.lazy` to lazily 'trace' a tensor network function. One can then inspect the computational graph, or compute it, optionally reusing automatically identified shared intermediates.\n", "\n", "```{hint}\n", "See the main [`autoray` lazy computation docs here](https://autoray.readthedocs.io/en/latest/lazy_computation.html) for the full range of functionality.\n", "```" ] },
{ "cell_type": "code", "execution_count": 1, "id": "f167bd92-f3d4-42ac-b11b-9c8557c9ca2a", "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_formats = ['svg']\n", "import autoray as ar\n", "\n", "import quimb as qu\n", "import quimb.tensor as qtn" ] },
{ "cell_type": "markdown", "id": "9d71f4ee-92d3-4a77-8c8f-a10e396b0454", "metadata": {}, "source": [ "We'll use the basic example of evaluating a two-site observable on every adjacent pair of sites in an MPS (usually, an MPO expectation value or manually reused environments would be the most efficient way to do this):" ] },
{ "cell_type": "code", "execution_count": 2, "id": "fb1433f7-9e05-4951-a2d0-039d0c535ae2", "metadata": {}, "outputs": [], "source": [ "# random MPS with 20 sites and bond dimension 10\n", "mps = qtn.MPS_rand_state(20, 10)\n", "# random real Hermitian two-site operator (4x4)\n", "G = qu.rand_herm(4, dtype=\"float64\")" ] },
{ "cell_type": "markdown", "id": "6ec4a108-0c6e-4545-96ff-85f097cdb876", "metadata": {}, "source": [ "First, the eager computation:" ] },
{ "cell_type": "code", "execution_count": 3, "id": "54b81c55-d8bc-4336-9dc3-d569d1b70215", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 49.1 ms, sys: 2.49 ms, total: 51.6 ms\n", "Wall time: 91.5 ms\n" ] }, { "data": { "text/plain": [ "np.float64(-2.1507695767849517)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "%%time\n", "ex_eager = sum(\n", " mps.local_expectation_exact(G, (i, i + 1)) for i in range(mps.L - 1)\n", ")\n", "ex_eager" ] }, { "cell_type": "markdown", "id": "6c95590d-cd6c-4ba3-847e-06dbf4e85fe3", "metadata": {}, "source": [ "## Lazily tracing the computation\n", "\n", "Now we can wrap all the underlying arrays as {class}`autoray.lazy.LazyArray` instances:" ] }, { "cell_type": "code", "execution_count": 4, "id": "2e529977-f43a-4ce3-99ef-b5bddb52dbf4", "metadata": {}, "outputs": [], "source": [ "mps.apply_to_arrays(ar.lazy.array)\n", "lG = ar.lazy.array(G)" ] }, { "cell_type": "markdown", "id": "9bdd41d4-bb9a-416c-be31-f4550d1a2ab8", "metadata": {}, "source": [ "Now we when call the same expectation, the operations are only traced:" ] }, { "cell_type": "code", "execution_count": 5, "id": "baefb77c-5bff-4289-ba4a-6a508bee6149", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ex_lazy = sum(\n", " mps.local_expectation_exact(lG, (i, i + 1)) for i in range(mps.L - 1)\n", ")\n", "ex_lazy" ] }, { "cell_type": "markdown", "id": "8b80bd04-8f9d-4aad-89ba-c705b369f1dd", "metadata": {}, "source": [ "That allows us to perform introspection, like the mix of array functions used:" ] }, { "cell_type": "code", "execution_count": 6, "id": "c969a315-4c05-4d8c-a511-209e188bdb9e", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ex_lazy.plot_history_functions();" ] }, { "cell_type": "markdown", "id": "1ff06e2a-ffff-415a-b378-fd3124c89f38", "metadata": {}, "source": [ "Or the rough memory footprint in terms of concurrent array elements:" ] }, { "cell_type": "code", "execution_count": 7, "id": "9df01402-f2ef-40f9-986d-f7ad70ab8617", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ex_lazy.plot_history_size_footprint();" ] }, { "cell_type": "markdown", "id": "f3113a55-f946-420d-9abc-2083940fae0e", "metadata": {}, "source": [ "We can compute the actual value by calling [`.compute()`](autoray.lazy.LazyArray.compute), though note this also clears the computational graph, so you should do any introspection first:" ] }, { "cell_type": "code", "execution_count": 8, "id": "53647702-4ae5-4a1a-993c-f904436ccddd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 13.9 ms, sys: 2.17 ms, total: 16 ms\n", "Wall time: 18.2 ms\n" ] }, { "data": { "text/plain": [ "np.float64(-2.1507695767849517)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "ar.lazy.compute(ex_lazy)" ] }, { "cell_type": "markdown", "id": "e48fb47f-f8f0-42a0-8df3-425c0babf6ad", "metadata": {}, "source": [ "The computation is faster as only the numeric operations are run, all the `quimb` infrastructure that essentially composed the computation ran earlier." 
] }, { "cell_type": "markdown", "id": "0454e4c6-8f93-4570-abbf-e9c4057cbbc1", "metadata": {}, "source": [ "## Reusing intermediates\n", "\n", "One thing that lazy tracing allows is the automatic identification of repeated intermediates (without blindly caching everything):" ] },
{ "cell_type": "code", "execution_count": 9, "id": "fe037918-b47f-4d2c-9816-1accb8708997", "metadata": {}, "outputs": [], "source": [ "with ar.lazy.shared_intermediates():\n", "    ex_lazy_reuse = sum(\n", "        mps.local_expectation_exact(lG, (i, i + 1)) for i in range(mps.L - 1)\n", "    )" ] },
{ "cell_type": "markdown", "id": "6811a05f-8fe8-485e-a9e0-7ac548239dc1", "metadata": {}, "source": [ "Now if we look at the functions used, the list should be far shorter:" ] },
{ "cell_type": "code", "execution_count": 10, "id": "bb5d3b8b-c388-410d-814b-5da49c85ac24", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ex_lazy_reuse.plot_history_functions();" ] },
{ "cell_type": "markdown", "id": "05038e72-0860-442a-9919-865a8a24a09d", "metadata": {}, "source": [ "However, the memory footprint might also go up, since shared intermediates have to be kept in memory until their last use:" ] },
{ "cell_type": "code", "execution_count": 11, "id": "90a03f3b-a15b-4305-8272-07bd4da28e3f", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ex_lazy_reuse.plot_history_size_footprint();" ] },
{ "cell_type": "markdown", "id": "d5b238c2-6399-4122-be8a-75fa448cf96a", "metadata": {}, "source": [ "Note that we haven't done anything here to favor reuse, so, for example, slightly different contraction orders for each term might prevent intermediates from being shared. Generally, for best performance, one would want either to explicitly favor repeated patterns to maximize reuse, or to set up intermediates / environments explicitly.\n", "\n", "Again, we can just call `.compute()` to get the value, and it should be slightly quicker:" ] },
{ "cell_type": "code", "execution_count": 12, "id": "af1bbab9-5278-4c16-b68d-9c3ee244ed92", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 6.77 ms, sys: 1.28 ms, total: 8.04 ms\n", "Wall time: 15.3 ms\n" ] }, { "data": { "text/plain": [ "np.float64(-2.1507695767849517)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "ar.lazy.compute(ex_lazy_reuse)" ] },
{ "cell_type": "markdown", "id": "c5894994-7260-4f62-90a0-5bab0fc5d63f", "metadata": {}, "source": [ "## Compiling the computation\n", "\n", "For repeated computations, one can use `autoray` to set up a pure 'array function': a function that maps new input arrays directly to the output, running only the recorded array operations." 
] }, { "cell_type": "code", "execution_count": 13, "id": "d3ebd511-c776-4831-b183-56959fdee930", "metadata": {}, "outputs": [], "source": [ "# rebuild our computational graph\n", "with ar.lazy.shared_intermediates():\n", "    output = sum(\n", "        mps.local_expectation_exact(G, (i, i + 1)) for i in range(mps.L - 1)\n", "    )" ] },
{ "cell_type": "code", "execution_count": 14, "id": "210f4a7b-5aaa-4eb5-bc11-47c0eb8943bf", "metadata": {}, "outputs": [], "source": [ "# get the function from the input variables to this output\n", "# note that `G` will be treated as constant since we don't supply it here\n", "inputs = list(mps.arrays)\n", "foo = output.get_function(inputs)" ] },
{ "cell_type": "markdown", "id": "237bb8fa-5908-466c-85e2-0077b67e56f2", "metadata": {}, "source": [ "Now we can supply any set of matching arrays to this function, and only the array functions will be called, inside an `exec`-compiled block:" ] },
{ "cell_type": "code", "execution_count": 15, "id": "cffc4eb5-220e-4cf0-b820-64cbc5ab230b", "metadata": {}, "outputs": [], "source": [ "# get a new TN with matching arrays\n", "mps2 = qtn.MPS_rand_state(20, 10)" ] },
{ "cell_type": "code", "execution_count": 16, "id": "2ef03ebc-f351-4f3b-ba6d-cda7ff92f30b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 5.7 ms, sys: 1.7 ms, total: 7.41 ms\n", "Wall time: 8.82 ms\n" ] }, { "data": { "text/plain": [ "np.float64(-2.2266568847709034)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "foo(mps2.arrays)" ] },
{ "cell_type": "markdown", "id": "c2020335-63fd-42c1-bc1c-0a96e21d35b8", "metadata": {}, "source": [ "Note that the order of the arrays, and of each array's axes, must match the original computation." ] },
{ "cell_type": "markdown", "id": "411613cb-b8b9-4bb6-8fa3-e0f738e171d1", "metadata": {}, "source": [ "Alternatively, one can just use `ar.autojit` to automate this process, which also allows using different backends. `autojit` automatically reuses shared intermediates and folds constants." ] },
{ "cell_type": "code", "execution_count": 17, "id": "0b7a9a16-a7c2-42c9-852a-a049ffb10019", "metadata": {}, "outputs": [], "source": [ "@ar.autojit\n", "def bar(arrays, G):\n", "    # insert the (traced) input arrays into a copy of the TN\n", "    mps_trace = mps2.copy()\n", "    for t, ary in zip(mps_trace, arrays):\n", "        t.modify(data=ary)\n", "    return sum(\n", "        mps_trace.local_expectation_exact(G, (i, i + 1))\n", "        for i in range(mps_trace.L - 1)\n", "    )" ] },
{ "cell_type": "markdown", "id": "8c318b95", "metadata": {}, "source": [ "This pattern (a function that takes the raw arrays, inserts them into a copy of the tensor network, and returns the result) is also the general setup for using other tracing decorators, such as `jax.jit`." ] },
{ "cell_type": "code", "execution_count": 18, "id": "68a46273-5ffc-4c00-87ed-7464d20b824c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 52.5 ms, sys: 3.89 ms, total: 56.4 ms\n", "Wall time: 69.8 ms\n" ] }, { "data": { "text/plain": [ "np.float64(-2.2266568847709034)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "# the first call incurs some overhead from tracing\n", "bar(mps2.arrays, G)" ] },
{ "cell_type": "code", "execution_count": 19, "id": "f65c8e95-d013-4068-8377-65f951a2f7bf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.99 ms, sys: 1.62 ms, total: 6.61 ms\n", "Wall time: 12.5 ms\n" ] }, { "data": { "text/plain": [ "np.float64(-2.2266568847709034)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "# subsequent calls reuse the traced function and are fast\n", "bar(mps2.arrays, G)" ] },
{ "cell_type": "markdown", "id": "249db429-b66f-40a7-8c17-2ae4454c73f0", "metadata": {}, "source": [ "Or we can use one of the other compilation backends:" ] },
{ "cell_type": "code", "execution_count": 20, "id": "5bffb5be-7807-4bea-9af0-adfb17e44d46", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.89 s, sys: 217 ms, total: 3.1 s\n", "Wall time: 1.73 s\n" ] }, { "data": { "text/plain": [ "array(-2.226657, dtype=float32)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "# the first jax call is slower, as it includes compilation overhead\n", "bar(mps2.arrays, G, backend=\"jax\")" ] },
{ "cell_type": "code", "execution_count": 21, "id": "495166cb-34ad-4e99-af05-12975030de78", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.62 ms, sys: 995 μs, total: 2.62 ms\n", "Wall time: 2.03 ms\n" ] }, { "data": { "text/plain": [ "array(-2.226657, dtype=float32)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "bar(mps2.arrays, G, backend=\"jax\")" ] },
{ "cell_type": "markdown", "id": "c9b8d621", "metadata": {}, "source": [ "```{warning}\n", "Note that jax defaults to single precision unless configured otherwise (e.g. with `jax.config.update(\"jax_enable_x64\", True)`), which is why the result above is `float32`.\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3" } }, "nbformat": 4, "nbformat_minor": 5 }