User Guide Laplace Reconstruct

Basic usage

This library provides one main interface, laplace_reconstruct, which uses a selected inverse Laplace transform (ILT) algorithm to reconstruct trajectories from a provided parameterized Laplace representation functional \(\mathbf{F}(\mathbf{p},\mathbf{s})\),

\[\mathbf{x}(t) = \mathcal{L}^{-1}\{\mathbf{F}(\mathbf{p},\mathbf{s})\}(t)\]

where \(\mathbf{p}\) is a Tensor encoding the initial system state as a latent variable, and \(t\) is a Tensor of time points at which to reconstruct the trajectories.

It can be used as follows:

[ ]:
from torchlaplace import laplace_reconstruct

laplace_reconstruct(laplace_rep_func, p, t)

where laplace_rep_func is any callable implementing the parameterized Laplace representation functional \(\mathbf{F}(\mathbf{p},\mathbf{s})\), and p is a Tensor of shape \((\text{MiniBatchSize},\text{K})\) encoding the initial state, where \(\text{K}\) is a hyperparameter that can be set by the user. Finally, t is a Tensor of shape \((\text{MiniBatchSize},\text{SeqLen})\) or \((\text{SeqLen})\) containing the time points at which to reconstruct the trajectories.
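
For concreteness, a minimal sketch of the expected shapes (the sizes here are illustrative, and laplace_rep_func is assumed to be defined as in the example below):

[ ]:
import torch

batch_size, K, seq_len = 128, 2, 100      # illustrative sizes
p = torch.randn(batch_size, K)            # latent encoding of the initial state
t = torch.linspace(0.1, 20.0, seq_len)    # shared time points, shape (SeqLen,)

# x = laplace_reconstruct(laplace_rep_func, p, t)
# x has shape (batch_size, seq_len, d) for a d-dimensional reconstruction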

Note that reconstruction is not numerically stable for all ILT methods; however, it should generally be fine with the default fourier (Fourier-series inverse) ILT algorithm.

The parameterized Laplace representation functional laplace_rep_func, \(\mathbf{F}(\mathbf{p},\mathbf{s})\), also takes a complex-valued input \(\mathbf{s}\). This \(\mathbf{s}\) is used internally when reconstructing a specified time point with the selected inverse Laplace transform algorithm ilt_algorithm.
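
The ILT algorithm is chosen via the ilt_algorithm keyword argument; for example, passing the default Fourier-series inverse explicitly:

[ ]:
laplace_reconstruct(laplace_rep_func, p, t, ilt_algorithm="fourier")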

The biggest gotcha is that laplace_rep_func must be an nn.Module when using the laplace_reconstruct function. This is because the parameters of the parameterized Laplace representation need to be collected internally.

To replicate the experiments in [1], see the scripts in the experiments directory.

Example

Define an encoder and a Laplace representation functional \(\mathbf{F}(\mathbf{p},\mathbf{s})\)

[1]:
from torch import nn
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model (encoder and Laplace representation func)
class ReverseGRUEncoder(nn.Module):
    # Encodes observed trajectory into latent vector
    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
        super(ReverseGRUEncoder, self).__init__()
        self.encode_obs_time = encode_obs_time
        if self.encode_obs_time:
            dimension_in += 1
        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
        self.linear_out = nn.Linear(hidden_units, latent_dim)
        nn.init.xavier_uniform_(self.linear_out.weight)

    def forward(self, observed_data, observed_tp):
        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
        if self.encode_obs_time:
            trajs_to_encode = torch.cat(
                (
                    observed_data,
                    observed_tp.view(1, -1, 1).repeat(observed_data.shape[0], 1, 1),
                ),
                dim=2,
            )
        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
        out, _ = self.gru(reversed_trajs_to_encode)
        return self.linear_out(out[:, -1, :])

class LaplaceRepresentationFunc(nn.Module):
    # SphereSurfaceModel : C^{b+k} -> C^{b x d} - in Riemann sphere coordinates: b is the number of s reconstruction terms, k is the latent encoding dimension, d is the output dimension
    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
        super(LaplaceRepresentationFunc, self).__init__()
        self.s_dim = s_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(s_dim * 2 + latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, (s_dim) * 2 * output_dim),
        )

        for m in self.linear_tanh_stack.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

        phi_max = torch.pi / 2.0
        phi_min = -torch.pi / 2.0
        self.phi_scale = phi_max - phi_min  # phi spans (-pi/2, pi/2)

    def forward(self, i):
        out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(
            -1, 2 * self.output_dim, self.s_dim
        )
        theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From - pi to + pi
        phi = (
            nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0
            - torch.pi / 2.0
            + self.phi_scale / 2.0
        )  # From -pi/2 to +pi/2
        return theta, phi
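
As a quick sanity check, the functional maps a flattened input of size s_dim * 2 + latent_dim (the real and imaginary parts of the \(\mathbf{s}\) query points together with the latent \(\mathbf{p}\)) to spherical coordinates theta and phi of shape (batch, output_dim, s_dim). An illustrative cell, using the sizes instantiated later:

[ ]:
s_dim, out_dim, lat_dim = 33, 1, 2  # sizes matching the instantiation below
func = LaplaceRepresentationFunc(s_dim, out_dim, lat_dim)
dummy = torch.randn(4, s_dim * 2 + lat_dim)  # arbitrary batch of 4
theta, phi = func(dummy)
print(theta.shape, phi.shape)  # torch.Size([4, 1, 33]) torch.Size([4, 1, 33])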

Load a dataset

[2]:
from torch.utils.data import DataLoader
from torchlaplace.data_utils import basic_collate_fn

normalize_dataset = True
batch_size = 128
extrapolate = True

def sawtooth(trajectories_to_sample=100, t_nsamples=200):
    # Toy sawtooth waveform. Simple to generate; for differential equation datasets see datasets.py (note: more complex DEs can take minutes to sample from).
    t_end = 20.0
    t_begin = t_end / t_nsamples
    ti = torch.linspace(t_begin, t_end, t_nsamples).to(device)

    def sampler(t, x0=0):
        return (t + x0) / (2 * torch.pi) - torch.floor((t + x0) / (2 * torch.pi))

    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample)
    trajs = []
    for x0 in x0s:
        trajs.append(sampler(ti, x0))
    y = torch.stack(trajs)
    trajectories = y.view(trajectories_to_sample, -1, 1)
    return trajectories, ti


trajectories, t = sawtooth(
    trajectories_to_sample=1000,
    t_nsamples=200,
)
if normalize_dataset:
    samples = trajectories.shape[0]
    dim = trajectories.shape[2]
    traj = (
        torch.reshape(trajectories, (-1, dim))
        - torch.reshape(trajectories, (-1, dim)).mean(0)
    ) / torch.reshape(trajectories, (-1, dim)).std(0)
    trajectories = torch.reshape(traj, (samples, -1, dim))
train_split = int(0.8 * trajectories.shape[0])
test_split = int(0.9 * trajectories.shape[0])
traj_index = torch.randperm(trajectories.shape[0])
train_trajectories = trajectories[traj_index[:train_split], :, :]
val_trajectories = trajectories[traj_index[train_split:test_split], :, :]
test_trajectories = trajectories[traj_index[test_split:], :, :]

input_dim = train_trajectories.shape[2]
output_dim = input_dim
dltrain = DataLoader(
    train_trajectories,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dlval = DataLoader(
    val_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="test",
        extrap=extrapolate,
    ),
)
dltest = DataLoader(
    test_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="test",
        extrap=extrapolate,
    ),
)
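
Each batch yielded by these loaders is a dictionary; with extrap=True, every 200-point trajectory is split into an observed segment and a segment to predict. A quick inspection of one validation batch (the keys are exactly those used in the cells below):

[ ]:
batch = next(iter(dlval))
for key in ("observed_data", "observed_tp", "tp_to_predict", "data_to_predict"):
    print(key, batch[key].shape)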

Instantiate model and sample prediction

[3]:
latent_dim = 2
hidden_units = 64
encode_obs_time = True
s_recon_terms = 33

encoder = ReverseGRUEncoder(
    input_dim,
    latent_dim,
    hidden_units // 2,
    encode_obs_time=encode_obs_time,
).to(device)
laplace_rep_func = LaplaceRepresentationFunc(
    s_recon_terms, output_dim, latent_dim
).to(device)
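
Both components are ordinary nn.Modules, so standard PyTorch utilities apply to them; for instance, counting trainable parameters:

[ ]:
total_params = sum(param.numel() for param in encoder.parameters()) + sum(
    param.numel() for param in laplace_rep_func.parameters()
)
print(f"trainable parameters: {total_params}")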

We can generate predictions using laplace_reconstruct as follows:

[4]:
from torchlaplace import laplace_reconstruct

laplace_rep_func.eval(), encoder.eval()
for batch in dlval:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(
        laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
    )
[5]:
predictions.shape
[5]:
torch.Size([100, 100, 1])
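
That is, \((\text{MiniBatchSize}, \text{SeqLen}, d)\): the single validation batch of 100 trajectories, reconstructed at the 100 extrapolated time points, with one output dimension.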

Train end to end

[6]:
from time import time
from copy import deepcopy

learning_rate = 1e-3
epochs = 1000
patience = None

if not patience:
    patience = epochs

params = list(laplace_rep_func.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)
loss_fn = torch.nn.MSELoss()

best_loss = float("inf")
waiting = 0

for epoch in range(epochs):
    iteration = 0
    epoch_train_loss_it_cum = 0
    start_time = time()
    laplace_rep_func.train(), encoder.train()
    for batch in dltrain:
        optimizer.zero_grad()
        trajs_to_encode = batch[
            "observed_data"
        ]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        p = encoder(
            trajs_to_encode, observed_tp
        )  # p is the latent tensor encoding the initial states
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
        loss = loss_fn(
            torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1)
        optimizer.step()
        epoch_train_loss_it_cum += loss.item()
        iteration += 1
    epoch_train_loss = epoch_train_loss_it_cum / iteration
    epoch_duration = time() - start_time

    # Validation step
    laplace_rep_func.eval(), encoder.eval()
    cum_val_loss = 0
    cum_val_batches = 0
    for batch in dlval:
        trajs_to_encode = batch[
            "observed_data"
        ]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        p = encoder(
            trajs_to_encode, observed_tp
        )  # p is the latent tensor encoding the initial states
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
        cum_val_loss += loss_fn(
            torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
        ).item()
        cum_val_batches += 1
    val_mse = cum_val_loss / cum_val_batches
    if epoch % 100 == 0:
        print(
            "[epoch={}] epoch_duration={:.2f} | train_loss={}\t| val_mse={}\t|".format(
                epoch, epoch_duration, epoch_train_loss, val_mse
            )
        )

    # Early stopping procedure
    if val_mse < best_loss:
        best_loss = val_mse
        best_laplace_rep_func = deepcopy(laplace_rep_func.state_dict())
        best_encoder = deepcopy(encoder.state_dict())
        waiting = 0
    elif waiting > patience:
        break
    else:
        waiting += 1

# Load best model
laplace_rep_func.load_state_dict(best_laplace_rep_func)
encoder.load_state_dict(best_encoder)

# Test step
laplace_rep_func.eval(), encoder.eval()
cum_test_loss = 0
cum_test_batches = 0
for batch in dltest:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(
        laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
    )
    cum_test_loss += loss_fn(
        torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
    ).item()
    cum_test_batches += 1
test_mse = cum_test_loss / cum_test_batches
print(f"test_mse= {test_mse}")
[epoch=0] epoch_duration=0.14 | train_loss=1.1861835888453893   | val_mse=1.0386548042297363    |
[epoch=100] epoch_duration=0.05 | train_loss=0.26202377889837536        | val_mse=0.28236162662506104   |
[epoch=200] epoch_duration=0.02 | train_loss=0.13380443517650878        | val_mse=0.13365934789180756   |
[epoch=300] epoch_duration=0.02 | train_loss=0.08324573934078217        | val_mse=0.08065777271986008   |
[epoch=400] epoch_duration=0.02 | train_loss=0.06641621195844241        | val_mse=0.06298608332872391   |
[epoch=500] epoch_duration=0.02 | train_loss=0.055839442781039646       | val_mse=0.06235235556960106   |
[epoch=600] epoch_duration=0.02 | train_loss=0.056307445147207806       | val_mse=0.07489385455846786   |
[epoch=700] epoch_duration=0.02 | train_loss=0.051500603024448664       | val_mse=0.045799192041158676  |
[epoch=800] epoch_duration=0.02 | train_loss=0.04638872029525893        | val_mse=0.04121057316660881   |
[epoch=900] epoch_duration=0.02 | train_loss=0.05361673395548548        | val_mse=0.05730640888214111   |
test_mse= 0.05796055868268013