{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# User Guide: Inverse Laplace Transform Algorithms\n", "\n", "## Laplace Theory Background\n", "\n", "The Laplace transform of a trajectory $\\mathbf{x}$ is defined as\n", "\n", "$$\\mathbf{F}(\\mathbf{s})=\\mathcal{L}\\{\\mathbf{x}\\}(\\mathbf{s})=\\int_0^\\infty e^{-\\mathbf{s}t} \\mathbf{x}(t) dt $$\n", "\n", "where $\\mathbf{s}\\in \\mathbb{C}^d$ is a vector of *complex* numbers and $\\mathbf{F}(\\mathbf{s}) \\in \\mathbb{C}^d$ is called the *Laplace representation*.\n", "$\\mathbf{F}(\\mathbf{s})$ may have singularities, i.e. points where $\\mathbf{F}(\\mathbf{s})\\to \\mathbf{\\infty}$ for one component.\n", "\n", "For further background details and references see [[1]](https://arxiv.org/abs/2206.04843).\n", "\n", "The inverse Laplace transform (ILT) is defined as\n", "\n", "$$ \\hat{\\mathbf{x}}(t) = \\mathcal{L}^{-1}\\{\\mathbf{F}(\\mathbf{s})\\}(t)=\\frac{1}{2\\pi i} \\int_{\\sigma - i \\infty}^{\\sigma + i \\infty} \\mathbf{F}(\\mathbf{s})e^{\\mathbf{s}t}d\\mathbf{s} $$\n", "\n", "where the integral refers to the Bromwich contour integral in $\\mathbb{C}^d$ with the contour $\\sigma>0$ chosen such that all the singularities of $\\mathbf{F}(\\mathbf{s})$ are to the left of it [[1]](https://arxiv.org/abs/2206.04843).\n", "\n", "Many algorithms have been developed to numerically evaluate the ILT integral above. At a high level, they involve two steps:\n", "\n", "$$ \\mathcal{Q}(t) = \\text{ILT-Query} (t) $$\n", "$$ \\hat{\\mathbf{x}}(t) = \\text{ILT-Compute}\\big(\\{\\mathbf{F}(\\mathbf{s})| \\mathbf{s} \\in \\mathcal{Q}(t) \\}\\big) $$\n", "\n", "To evaluate $\\mathbf{x}(t)$ at time points $t \\in \\mathcal{T} \\subset \\mathbb{R}$, the algorithms first construct a set of *query points* $\\mathbf{s} \\in \\mathcal{Q}(\\mathcal{T}) \\subset \\mathbb{C}$. 
They then compute $\\hat{\\mathbf{x}}(t)$ from $\\mathbf{F}(\\mathbf{s})$ evaluated at these points.\n", "The number of query points scales *linearly* with the number of time points, i.e. $|\\mathcal{Q}(\\mathcal{T})| = b |\\mathcal{T}|$, where the constant $b > 1$ denotes the number of reconstruction terms per time point and is specific to the algorithm.\n", "Importantly, the computational complexity of the ILT depends only on the *number* of time points, not on their values (e.g. the ILT for $t=0$ and $t=100$ requires the same amount of computation).\n", "The vast majority of ILT algorithms are differentiable with respect to $\\mathbf{F}(\\mathbf{s})$, which allows gradients to be backpropagated through the ILT [[1]](https://arxiv.org/abs/2206.04843).\n", "\n", "## Code examples of ILT algorithms\n", "\n", "The following examples show how to use each inverse Laplace transform algorithm individually, given a known Laplace representation $\\mathbf{F}(\\mathbf{p},\\mathbf{s})$ and a set of time points at which to reconstruct the signal." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from time import time\n", "import numpy as np\n", "import torch\n", "\n", "from torchlaplace.inverse_laplace import CME, DeHoog, FixedTablot, Fourier, Stehfest\n", "\n", "time_points_to_reconstruct = 1000\n", "s_recon_terms = 33\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "\n", "np.random.seed(999)\n", "torch.random.manual_seed(0)\n", "\n", "t = torch.linspace(0.0001, 10.0, time_points_to_reconstruct).to(device)\n", "\n", "# Test signal: cosine, whose Laplace transform is s / (s^2 + 1)\n", "def fs(so):\n", "    return so / (so**2 + 1)  # Laplace representation F(s)\n", "\n", "def ft(t):\n", "    return torch.cos(t)  # Ground-truth time-domain signal x(t)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### FixedTablot\n", "\n", "Evaluates a separate set of $\\mathbf{s}$ query points for each time input (the default, as it gives a more accurate inversion)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FixedTablot Loss:\t0.4364858865737915\t\t| time: 0.001544952392578125\n", "FixedTablot Loss (Split apart):\t0.4364858567714691\t| time: 0.0011589527130126953\n", "FixedTablot Loss (Split apart, Fixed Max Time):\t1077.5205078125\t| time: 0.0008168220520019531\n", "FixedTablot Loss (Fixed Max Time):\t1077.5205078125\t| time: 0.000659942626953125\n" ] } ], "source": [ "decoder = FixedTablot(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "f_hat_t = decoder(fs, t)\n", "print(\n", " \"FixedTablot Loss:\\t{}\\t\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")\n", "\n", "# Split evaluation of s points out from that of the line integral (should be the exact same result as above)\n", "decoder = FixedTablot(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "s, _ = decoder.compute_s(t)\n", "fh = fs(s)\n", "f_hat_t = decoder.line_integrate(fh, t)\n", "print(\n", " \"FixedTablot Loss (Split 
apart):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")\n", "\n", "# Split evaluation, with the s points computed once for the maximum time\n", "decoder = FixedTablot(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "s, _ = decoder.compute_s(t, time_max=torch.max(t))\n", "fh = fs(s)\n", "f_hat_t = decoder.line_integrate(fh, t, time_max=t.max().item())\n", "print(\n", " \"FixedTablot Loss (Split apart, Fixed Max Time):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")\n", "\n", "# Evaluate s points once, for a single fixed time (the maximum time). Less accurate, but possibly more stable.\n", "\n", "decoder = FixedTablot(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "f_hat_t = decoder(fs, t, time_max=torch.max(t))\n", "print(\n", " \"FixedTablot Loss (Fixed Max Time):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### Stehfest\n", "\n", "Increasing the number of reconstruction terms (the degree) introduces numerical error that grows faster than in the other methods, so Stehfest becomes unstable at high degree.\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Stehfest Loss:\t25.583345413208008\t| time: 0.0011260509490966797\n", "Stehfest Loss (Split apart):\t25.583345413208008\t| time: 0.0007259845733642578\n" ] } ], "source": [ "\n", "decoder = Stehfest(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "f_hat_t = decoder(fs, t)\n", "print(\n", " \"Stehfest Loss:\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")\n", "\n", "decoder = Stehfest(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "s = decoder.compute_s(t)\n", "fh = fs(s)\n", "f_hat_t = decoder.line_integrate(fh, 
t)\n", "print(\n", " \"Stehfest Loss (Split apart):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fourier\n", "\n", "The unaccelerated variant of DeHoog." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fourier (Unaccelerated DeHoog) Loss:\t0.01714298129081726\t| time: 0.0023889541625976562\n", "Fourier (Unaccelerated DeHoog) Loss (Split apart):\t0.01714298129081726\t| time: 0.0010972023010253906\n" ] } ], "source": [ "decoder = Fourier(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "f_hat_t = decoder(fs, t)\n", "print(\n", " \"Fourier (Unaccelerated DeHoog) Loss:\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")\n", "\n", "# Split evaluation of s points out from that of the line integral (should be the exact same result as above)\n", "decoder = Fourier(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "s, T = decoder.compute_s(t)\n", "fh = fs(s)\n", "f_hat_t = decoder.line_integrate(fh, t, T)\n", "print(\n", " \"Fourier (Unaccelerated DeHoog) Loss (Split apart):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DeHoog" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DeHoog Loss:\t1.2498872820287943e-05\t| time: 0.016785144805908203\n", "DeHoog Loss (Split apart):\t1.2498872820287943e-05\t| time: 0.018165111541748047\n", "DeHoog Loss (Fixed Line Integrate):\t0.0342152863740921\t| time: 0.0032520294189453125\n", "DeHoog Loss (Fixed Max Time):\t0.03613712266087532\t| time: 0.012520313262939453\n" ] } ], "source": [ "decoder = DeHoog(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "f_hat_t = decoder(fs, 
t)\n", "print(\n", " \"DeHoog Loss:\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")\n", "\n", "# Split evaluation of s points out from that of the line integral (should be the exact same result as above)\n", "decoder = DeHoog(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "s, T = decoder.compute_s(t)\n", "fh = fs(s)\n", "f_hat_t = decoder.line_integrate(fh, t, T)\n", "print(\n", " \"DeHoog Loss (Split apart):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")\n", "\n", "# Single line integral\n", "decoder = DeHoog(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "s = decoder.compute_fixed_s(torch.max(t))\n", "fh = fs(s)\n", "f_hat_t = decoder.fixed_line_integrate(fh, t, torch.max(t))\n", "print(\n", " \"DeHoog Loss (Fixed Line Integrate):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")\n", "\n", "decoder = DeHoog(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "f_hat_t = decoder(fs, t, time_max=torch.max(t))\n", "print(\n", " \"DeHoog Loss (Fixed Max Time):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### CME" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CME Loss:\t0.0068940832279622555\t| time: 0.0011489391326904297\n", "CME Loss (Split apart):\t0.0068940832279622555\t| time: 0.0009069442749023438\n" ] } ], "source": [ "decoder = CME(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "f_hat_t = decoder(fs, t)\n", "print(\n", " \"CME Loss:\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - 
t0\n", " )\n", ")\n", "\n", "decoder = CME(ilt_reconstruction_terms=s_recon_terms).to(device)\n", "t0 = time()\n", "s, T = decoder.compute_s(t)\n", "fh = fs(s)\n", "f_hat_t = decoder.line_integrate(fh, t, T)\n", "print(\n", " \"CME Loss (Split apart):\\t{}\\t| time: {}\".format(\n", " np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0\n", " )\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.4 ('nl1')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "2b7f629d6ed90f49c5498428e19562b8bf096794f35bacb74d7d55f53bb5faf9" } } }, "nbformat": 4, "nbformat_minor": 2 }