User Guide Inverse Laplace Transform Algorithms

Laplace Theory Background

The Laplace transform of trajectory \(\mathbf{x}\) is defined as

\[\mathbf{F}(\mathbf{s})=\mathcal{L}\{\mathbf{x}\}(\mathbf{s})=\int_0^\infty e^{-\mathbf{s}t} \mathbf{x}(t) dt\]

where \(\mathbf{s}\in \mathbb{C}^d\) is a vector of complex numbers and \(\mathbf{F}(\mathbf{s}) \in \mathbb{C}^d\) is called the Laplace representation. \(\mathbf{F}(\mathbf{s})\) may have singularities, i.e. points where \(\mathbf{F}(\mathbf{s})\to \mathbf{\infty}\) in at least one component.
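The definition can be checked numerically for a scalar case. The sketch below is not part of torchlaplace: it approximates the integral with a truncated trapezoidal rule in plain Python and compares it with the known transform of \(\cos(t)\), \(F(s)=s/(s^2+1)\). The helper name `laplace_transform`, the truncation point `t_max` and the step count `n` are illustrative assumptions.

```python
import cmath
import math


def laplace_transform(x, s, t_max=60.0, n=60_000):
    """Approximate F(s) = int_0^inf exp(-s t) x(t) dt with the trapezoidal rule.

    The integral is truncated at t_max (an assumption of this sketch), so
    Re(s) must be large enough that exp(-Re(s) * t_max) is negligible.
    """
    h = t_max / n
    # End points carry half weight in the trapezoidal rule.
    total = 0.5 * (x(0.0) + cmath.exp(-s * t_max) * x(t_max))
    for k in range(1, n):
        t = k * h
        total += cmath.exp(-s * t) * x(t)
    return total * h


s = 1.0 + 2.0j                    # one complex query point
numeric = laplace_transform(math.cos, s)
closed_form = s / (s**2 + 1)      # known Laplace transform of cos(t)
error = abs(numeric - closed_form)
print(numeric, closed_form, error)
```

The same comparison works for any \(s\) with a positive real part, as long as `t_max` is large enough for the integrand to have decayed.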

For further background details and references see [1].

The inverse Laplace transform (ILT) is defined as

\[\hat{\mathbf{x}}(t) = \mathcal{L}^{-1}\{\mathbf{F}(\mathbf{s})\}(t)=\frac{1}{2\pi i} \int_{\sigma - i \infty}^{\sigma + i \infty} \mathbf{F}(\mathbf{s})e^{\mathbf{s}t}d\mathbf{s}\]

where the integral refers to the Bromwich contour integral in \(\mathbb{C}^d\) with the contour \(\sigma>0\) chosen such that all the singularities of \(\mathbf{F}(\mathbf{s})\) are to the left of it [1].
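
As a worked scalar example of the contour condition, take \(F(s)=s/(s^2+1)\), the Laplace transform of \(\cos(t)\). Its only singularities are simple poles at \(s=\pm i\), so any \(\sigma>0\) leaves both to the left of the contour. For \(t>0\) the contour can be closed to the left, and the residue theorem gives

\[\hat{x}(t)=\sum_{s_0=\pm i}\operatorname{Res}_{s=s_0}\frac{s\,e^{st}}{(s-i)(s+i)}=\frac{i\,e^{it}}{2i}+\frac{(-i)\,e^{-it}}{-2i}=\frac{e^{it}+e^{-it}}{2}=\cos(t)\]

which recovers the time-domain function exactly.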

Many algorithms have been developed to evaluate the ILT integral above numerically. At a high level, they involve two steps:

\[\mathcal{Q}(t) = \text{ILT-Query} (t)\]
\[\hat{\mathbf{x}}(t) = \text{ILT-Compute}\big(\{\mathbf{F}(\mathbf{s})| \mathbf{s} \in \mathcal{Q}(t) \}\big)\]
To evaluate \(\mathbf{x}(t)\) at time points \(t \in \mathcal{T} \subset \mathbb{R}\), the algorithms first construct a set of query points \(\mathbf{s} \in \mathcal{Q}(\mathcal{T}) \subset \mathbb{C}\). They then compute \(\hat{\mathbf{x}}(t)\) using \(\mathbf{F}(\mathbf{s})\) evaluated at these points. The number of query points scales linearly with the number of time points, i.e. \(|\mathcal{Q}(\mathcal{T})| = b |\mathcal{T}|\), where the constant \(b > 1\) denotes the number of reconstruction terms per time point and is specific to the algorithm. Importantly, the computational complexity of the ILT depends only on the number of time points, not on their values (e.g. the ILT for \(t=0\) and \(t=100\) requires the same amount of computation).

The vast majority of ILT algorithms are differentiable with respect to \(\mathbf{F}(\mathbf{s})\), which allows gradients to be backpropagated through the ILT [1].
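
The two-step structure can be sketched in plain Python using the fixed Talbot rule of Abate and Valkó, with the usual parameter choice \(r = 2M/(5t)\). This is an illustration under stated assumptions, not the torchlaplace implementation: the function names `ilt_query` and `ilt_compute` are hypothetical and simply mirror ILT-Query and ILT-Compute above, and everything runs in double precision on scalars.

```python
import cmath
import math


def ilt_query(t, terms):
    """ILT-Query: return `terms` contour points s(theta_k) for one time t > 0."""
    r = 2.0 * terms / (5.0 * t)
    points = [complex(r, 0.0)]  # theta = 0, where theta * cot(theta) -> 1
    for k in range(1, terms):
        theta = k * math.pi / terms
        points.append(r * theta * complex(1.0 / math.tan(theta), 1.0))
    return points


def ilt_compute(f_values, t):
    """ILT-Compute: weighted sum of the F(s) values sampled at ilt_query(t, terms)."""
    terms = len(f_values)
    r = 2.0 * terms / (5.0 * t)
    total = 0.5 * (f_values[0] * math.exp(r * t)).real
    for k in range(1, terms):
        theta = k * math.pi / terms
        cot = 1.0 / math.tan(theta)
        s = r * theta * complex(cot, 1.0)
        sigma = theta + (theta * cot - 1.0) * cot
        total += (cmath.exp(t * s) * f_values[k] * complex(1.0, sigma)).real
    return r / terms * total


def F(s):
    return s / (s**2 + 1)  # Laplace representation of cos(t)


times = [0.25, 0.5, 1.0, 2.0]
terms = 24  # b, the number of reconstruction terms per time point
queries = {t: ilt_query(t, terms) for t in times}
x_hat = {t: ilt_compute([F(s) for s in queries[t]], t) for t in times}

n_queries = sum(len(q) for q in queries.values())
max_error = max(abs(x_hat[t] - math.cos(t)) for t in times)
print(n_queries, max_error)
```

In this sketch the number of query points is exactly `terms * len(times)`, whatever the values of the time points. `ilt_compute` is a fixed weighted sum of the sampled \(F(s)\) values, so it is linear in them, which is the property that makes gradients with respect to \(\mathbf{F}(\mathbf{s})\) straightforward for this family of methods.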

Code examples of ILT algorithms

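The following examples show how to use each inverse Laplace transform algorithm individually, given a known Laplace representation \(\mathbf{F}(\mathbf{p},\mathbf{s})\) and the time points at which to evaluate it. Each example prints the root-mean-square error against the exact time-domain solution (labelled "Loss") and the elapsed wall-clock time.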

[4]:
from time import time
import numpy as np
import torch

from torchlaplace.inverse_laplace import CME, DeHoog, FixedTablot, Fourier, Stehfest

time_points_to_reconstruct = 1000
s_recon_terms = 33
device = "cuda" if torch.cuda.is_available() else "cpu"


np.random.seed(999)
torch.random.manual_seed(0)

t = torch.linspace(0.0001, 10.0, time_points_to_reconstruct).to(device)

# Test problem: cosine
def fs(so):
    return so / (so**2 + 1)  # Laplace representation F(s) = s / (s^2 + 1)

def ft(t):
    return torch.cos(t)  # Exact time-domain solution x(t) = cos(t)

FixedTablot

Evaluate the s points separately for each input time (the default, since it gives a more accurate inversion).

[5]:
decoder = FixedTablot(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
f_hat_t = decoder(fs, t)
print(
    "FixedTablot Loss:\t{}\t\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

# Split evaluation of s points out from that of the line integral (should be the exact same result as above)
decoder = FixedTablot(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
s, _ = decoder.compute_s(t)
fh = fs(s)
f_hat_t = decoder.line_integrate(fh, t)
print(
    "FixedTablot Loss (Split apart):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

# Split evaluation, with the s points computed for the maximum time only
decoder = FixedTablot(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
s, _ = decoder.compute_s(t, time_max=torch.max(t))
fh = fs(s)
f_hat_t = decoder.line_integrate(fh, t, time_max=t.max().item())
print(
    "FixedTablot Loss (Split apart, Fixed Max Time):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

# Evaluate the s points once, for the maximum time only (less accurate, possibly more stable)

decoder = FixedTablot(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
f_hat_t = decoder(fs, t, time_max=torch.max(t))
print(
    "FixedTablot Loss (Fixed Max Time):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)
FixedTablot Loss:       0.4364858865737915              | time: 0.001544952392578125
FixedTablot Loss (Split apart): 0.4364858567714691      | time: 0.0011589527130126953
FixedTablot Loss (Split apart, Fixed Max Time): 1077.5205078125 | time: 0.0008168220520019531
FixedTablot Loss (Fixed Max Time):      1077.5205078125 | time: 0.000659942626953125

Stehfest

(Increasing the degree here introduces numerical error that grows faster than in the other methods, so the algorithm becomes unstable at high degree.)

[6]:

decoder = Stehfest(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
f_hat_t = decoder(fs, t)
print(
    "Stehfest Loss:\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

decoder = Stehfest(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
s = decoder.compute_s(t)
fh = fs(s)
f_hat_t = decoder.line_integrate(fh, t)
print(
    "Stehfest Loss (Split apart):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)
Stehfest Loss:  25.583345413208008      | time: 0.0011260509490966797
Stehfest Loss (Split apart):    25.583345413208008      | time: 0.0007259845733642578
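
The instability at high degree can be reproduced outside the library. The sketch below implements the textbook Gaver-Stehfest formula in plain Python and double precision. It is not the torchlaplace code, it only accepts an even number of terms, and it uses \(F(s)=1/(s+1)\) (the transform of \(e^{-t}\)) rather than the cosine above because Stehfest is best suited to smooth, non-oscillatory functions. The function names are illustrative.

```python
import math
from fractions import Fraction


def stehfest_weights(n):
    """Gaver-Stehfest weights V_1..V_n (n even), computed exactly, then rounded."""
    half = n // 2
    weights = []
    for k in range(1, n + 1):
        total = Fraction(0)
        for j in range((k + 1) // 2, min(k, half) + 1):
            total += Fraction(
                j**half * math.factorial(2 * j),
                math.factorial(half - j)
                * math.factorial(j)
                * math.factorial(j - 1)
                * math.factorial(k - j)
                * math.factorial(2 * j - k),
            )
        weights.append(float((-1) ** (k + half) * total))
    return weights


def stehfest(F, t, n):
    """Approximate f(t) from n real samples of F at s_k = k * ln(2) / t."""
    a = math.log(2.0) / t
    return a * sum(v * F(k * a) for k, v in enumerate(stehfest_weights(n), start=1))


def F(s):
    return 1.0 / (s + 1.0)  # Laplace representation of exp(-t)


t = 1.0
errors = {n: abs(stehfest(F, t, n) - math.exp(-t)) for n in (8, 12, 16, 32)}
for n, err in errors.items():
    print(f"N={n:2d}  abs error={err:.3e}")
```

The weights alternate in sign and grow very quickly with the degree, so the sum suffers catastrophic cancellation once their magnitude exceeds what the floating-point format can resolve. In double precision the error at 32 terms is far larger than at 12 terms, and the effect sets in earlier in single precision.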

Fourier

(Unaccelerated DeHoog)

[7]:
decoder = Fourier(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
f_hat_t = decoder(fs, t)
print(
    "Fourier (Un accelerated DeHoog) Loss:\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

decoder = Fourier(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
s, T = decoder.compute_s(t)
fh = fs(s)
f_hat_t = decoder.line_integrate(fh, t, T)
print(
    "Fourier (Un accelerated DeHoog) Loss (Split apart):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)
Fourier (Un accelerated DeHoog) Loss:   0.01714298129081726     | time: 0.0023889541625976562
Fourier (Un accelerated DeHoog) Loss (Split apart):     0.01714298129081726     | time: 0.0010972023010253906
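
To see why acceleration matters, the sketch below evaluates the plain Fourier-series approximation of the Bromwich integral in Python. It assumes that the `Fourier` class is built on this standard series, which the source only implies by calling it unaccelerated DeHoog. The function name, the contour position `a` and the choice of half period `T = 2t` are illustrative, and the test function is \(F(s)=1/(s+1)\), the transform of \(e^{-t}\).

```python
import cmath
import math


def fourier_ilt(F, t, terms, a=2.0, half_period=None):
    """Truncated Fourier-series approximation of the inverse Laplace transform.

    `a` is the real part of the contour and 2 * half_period is the period of
    the underlying periodic extension; both defaults are illustrative choices.
    """
    T = 2.0 * t if half_period is None else half_period
    total = 0.5 * F(complex(a, 0.0)).real
    for k in range(1, terms + 1):
        w = k * math.pi / T
        total += (F(complex(a, w)) * cmath.exp(1j * w * t)).real
    return math.exp(a * t) / T * total


def F(s):
    return 1.0 / (s + 1.0)  # Laplace representation of exp(-t)


t = 1.0
exact = math.exp(-t)
err_33 = abs(fourier_ilt(F, t, 33) - exact)
err_2000 = abs(fourier_ilt(F, t, 2000) - exact)
print(f"33 terms: {err_33:.3e}   2000 terms: {err_2000:.3e}")
```

Because \(F(s)\) decays only like \(1/s\) along the contour, the truncation error of the raw series shrinks roughly in proportion to one over the number of terms. The de Hoog method applies a convergence acceleration to the same series, which is consistent with the much smaller DeHoog loss reported in the next section for the same 33 terms.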

DeHoog

[8]:
decoder = DeHoog(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
f_hat_t = decoder(fs, t)
print(
    "DeHoog Loss:\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

# Split evaluation of s points out from that of the line integral (should be the exact same result as above)
decoder = DeHoog(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
s, T = decoder.compute_s(t)
fh = fs(s)
f_hat_t = decoder.line_integrate(fh, t, T)
print(
    "DeHoog Loss (Split apart):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

# Single line integral
decoder = DeHoog(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
s = decoder.compute_fixed_s(torch.max(t))
fh = fs(s)
f_hat_t = decoder.fixed_line_integrate(fh, t, torch.max(t))
print(
    "DeHoog Loss (Fixed Line Integrate):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

decoder = DeHoog(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
f_hat_t = decoder(fs, t, time_max=torch.max(t))
print(
    "DeHoog Loss (Fixed Max Time):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)
DeHoog Loss:    1.2498872820287943e-05  | time: 0.016785144805908203
DeHoog Loss (Split apart):      1.2498872820287943e-05  | time: 0.018165111541748047
DeHoog Loss (Fixed Line Integrate):     0.0342152863740921      | time: 0.0032520294189453125
DeHoog Loss (Fixed Max Time):   0.03613712266087532     | time: 0.012520313262939453

CME

[9]:
decoder = CME(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
f_hat_t = decoder(fs, t)
print(
    "CME Loss:\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)

decoder = CME(ilt_reconstruction_terms=s_recon_terms).to(device)
t0 = time()
s, T = decoder.compute_s(t)
fh = fs(s)
f_hat_t = decoder.line_integrate(fh, t, T)
print(
    "CME Loss (Split apart):\t{}\t| time: {}".format(
        np.sqrt(torch.nn.MSELoss()(ft(t), f_hat_t).cpu().numpy()), time() - t0
    )
)
CME Loss:       0.0068940832279622555   | time: 0.0011489391326904297
CME Loss (Split apart): 0.0068940832279622555   | time: 0.0009069442749023438