User Guide Laplace Reconstruct
Basic usage
This library provides one main interface, laplace_reconstruct, which uses a selected inverse Laplace transform (ILT) algorithm to reconstruct trajectories from a provided parameterized Laplace representation functional \(\mathbf{F}(\mathbf{p},\mathbf{s})\),
where \(\mathbf{p}\) is a Tensor encoding the initial system state as a latent variable, and \(t\) are the time points at which to reconstruct the trajectories.
It can be used as follows:
[ ]:
from torchlaplace import laplace_reconstruct
# Reconstruct trajectories at the time points t from the latent initial state p,
# using laplace_rep_func as the Laplace representation F(p, s).
laplace_reconstruct(laplace_rep_func, p, t)
where laplace_rep_func
is any callable implementing the parameterized Laplace representation functional \(\mathbf{F}(\mathbf{p},\mathbf{s})\) (but see the nn.Module requirement below); p
is a Tensor of shape \((\text{MiniBatchSize},\text{K})\) encoding the initial state, where \(\text{K}\) is a hyperparameter that can be set by the user; and t
is a Tensor of shape \((\text{MiniBatchSize},\text{SeqLen})\) or \((\text{SeqLen})\) containing the time points at which to reconstruct the trajectories.
Note that this is not numerically stable for all ILT methods; however, it should generally be fine with the default fourier
(Fourier series inverse) ILT algorithm.
The parameterized Laplace representation functional laplace_rep_func, \(\mathbf{F}(\mathbf{p},\mathbf{s})\), also takes a complex input value \(\mathbf{s}\). This \(\mathbf{s}\) is used internally when reconstructing a specified time point with the selected inverse Laplace transform algorithm, ilt_algorithm.
The biggest gotcha is that laplace_rep_func
must be an nn.Module
when using the laplace_reconstruct
function. This is because laplace_reconstruct internally needs to collect the parameters of the parameterized Laplace representation.
To replicate the experiments in [1], see the experiments directory.
Example
Define an encoder and a Laplace representation functional \(\mathbf{F}(\mathbf{p},\mathbf{s})\).
[1]:
from torch import nn
import torch
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Model (encoder and Laplace representation func)
class ReverseGRUEncoder(nn.Module):
    """Encode an observed trajectory into a latent vector with a GRU run backwards in time.

    Args:
        dimension_in: feature dimension of each observation.
        latent_dim: size of the returned latent vector.
        hidden_units: hidden size of the (2-layer) GRU.
        encode_obs_time: if True, each observation's time stamp is appended
            as one extra input feature.
    """

    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
        super().__init__()
        self.encode_obs_time = encode_obs_time
        if self.encode_obs_time:
            dimension_in += 1  # extra input channel carrying the time stamp
        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
        self.linear_out = nn.Linear(hidden_units, latent_dim)
        nn.init.xavier_uniform_(self.linear_out.weight)

    def forward(self, observed_data, observed_tp):
        """Return the latent encoding of ``observed_data``, shape (batch_size, latent_dim).

        Args:
            observed_data: (batch_size, t_observed_dim, observed_dim).
            observed_tp: observation times. Either one grid shared by the whole
                batch (exactly ``t_observed_dim`` elements, e.g. shape
                ``(t_observed_dim,)`` or ``(1, t_observed_dim)``) or one grid
                per sample, shape ``(batch_size, t_observed_dim)``. Only used
                when ``encode_obs_time`` is True.
        """
        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
        if self.encode_obs_time:
            batch_size, seq_len = observed_data.shape[0], observed_data.shape[1]
            if observed_tp.numel() == seq_len:
                # Shared grid: broadcast (1, T, 1) across the batch without copying.
                time_channel = observed_tp.reshape(1, seq_len, 1).expand(
                    batch_size, -1, -1
                )
            else:
                # Per-sample grids: (batch_size, T) -> (batch_size, T, 1).
                time_channel = observed_tp.reshape(batch_size, seq_len, 1)
            trajs_to_encode = torch.cat((observed_data, time_channel), dim=2)
        # Flip the time axis so the GRU reads the sequence from last to first observation.
        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
        out, _ = self.gru(reversed_trajs_to_encode)
        # Output at the final step, i.e. after the whole reversed sequence has been read.
        return self.linear_out(out[:, -1, :])
class LaplaceRepresentationFunc(nn.Module):
    """Parameterized Laplace representation F(p, s) expressed in Riemann-sphere coordinates.

    SphereSurfaceModel: C^{b+k} -> C^{b x d}, where b = ``s_dim`` is the number
    of s reconstruction terms, k = ``latent_dim`` is the latent encoding
    dimension and d = ``output_dim`` is the output dimension.

    ``forward`` reshapes its input to rows of ``s_dim * 2 + latent_dim`` real
    values and returns ``(theta, phi)``.
    """

    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
        super().__init__()
        self.s_dim = s_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(s_dim * 2 + latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, s_dim * 2 * output_dim),
        )
        for m in self.linear_tanh_stack.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
        phi_max = torch.pi / 2.0
        phi_min = -torch.pi / 2.0
        self.phi_scale = phi_max - phi_min  # width of the phi range (= pi)

    def forward(self, i):
        """Map inputs to ``(theta, phi)``, each of shape (batch, output_dim, s_dim)."""
        out = self.linear_tanh_stack(
            i.view(-1, self.s_dim * 2 + self.latent_dim)
        ).view(-1, 2 * self.output_dim, self.s_dim)
        # First output_dim channels -> theta in (-pi, pi).
        theta = torch.tanh(out[:, : self.output_dim, :]) * torch.pi
        # Remaining channels -> phi: tanh's (-1, 1) rescaled to (-pi/2, pi/2).
        phi = (
            torch.tanh(out[:, self.output_dim :, :]) * self.phi_scale / 2.0
            - torch.pi / 2.0
            + self.phi_scale / 2.0
        )  # From -pi / 2 to + pi / 2
        return theta, phi
Load a dataset
[2]:
from torch.utils.data import DataLoader
from torchlaplace.data_utils import basic_collate_fn
# Dataset / loader settings: standardise the data, minibatch size, and the
# flag forwarded to basic_collate_fn as ``extrap``.
normalize_dataset, batch_size, extrapolate = True, 128, True
def sawtooth(trajectories_to_sample=100, t_nsamples=200, t_end=20.0, target_device=None):
    """Generate phase-shifted sawtooth trajectories (a toy dataset).

    Simple to generate; for differential-equation datasets see datasets.py
    (more complex DEs take time to sample from, in some cases minutes).

    Args:
        trajectories_to_sample: number of trajectories, one per phase offset
            ``x0`` from ``linspace(0, 2*pi, trajectories_to_sample)``.
        t_nsamples: number of time points per trajectory.
        t_end: last time point; the grid is
            ``linspace(t_end / t_nsamples, t_end, t_nsamples)``.
        target_device: device for the returned tensors; ``None`` uses the
            module-level ``device``.

    Returns:
        ``(trajectories, ti)``: ``trajectories`` has shape
        ``(trajectories_to_sample, t_nsamples, 1)`` with values in [0, 1);
        ``ti`` has shape ``(t_nsamples,)``.
    """
    target_device = device if target_device is None else target_device
    t_begin = t_end / t_nsamples
    ti = torch.linspace(t_begin, t_end, t_nsamples).to(target_device)
    # Offsets live on the same device as ti, so the broadcast does not mix host and device tensors.
    # NOTE(review): linspace includes both 0 and 2*pi, which yield the same
    # (2*pi-periodic) trajectory; kept so the generated data is unchanged.
    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample).to(ti.device)
    # (N, 1) offsets + (1, T) times -> (N, T) in one op instead of a Python loop + stack.
    phase = (ti.unsqueeze(0) + x0s.unsqueeze(1)) / (2 * torch.pi)
    y = phase - torch.floor(phase)
    # unsqueeze (rather than view(N, -1, 1)) also works when N == 0.
    return y.unsqueeze(-1), ti
trajectories, t = sawtooth(
    trajectories_to_sample=1000,
    t_nsamples=200,
)
# Random 80% / 10% / 10% split of whole trajectories into train / val / test.
train_split = int(0.8 * trajectories.shape[0])
test_split = int(0.9 * trajectories.shape[0])
traj_index = torch.randperm(trajectories.shape[0])
train_trajectories = trajectories[traj_index[:train_split], :, :]
val_trajectories = trajectories[traj_index[train_split:test_split], :, :]
test_trajectories = trajectories[traj_index[test_split:], :, :]
if normalize_dataset:
    # Standardise every split with per-feature mean/std estimated on the TRAIN
    # split only. Statistics of the full dataset would leak validation/test
    # information into training.
    dim = train_trajectories.shape[2]
    flat_train = torch.reshape(train_trajectories, (-1, dim))
    mean = flat_train.mean(0)
    std = flat_train.std(0)
    # Avoid dividing by zero (or NaN) for constant / degenerate features.
    std = torch.where(std > 0, std, torch.ones_like(std))
    # (N, T, dim) - (dim,) broadcasts over the trailing feature axis.
    trajectories = (trajectories - mean) / std
    train_trajectories = (train_trajectories - mean) / std
    val_trajectories = (val_trajectories - mean) / std
    test_trajectories = (test_trajectories - mean) / std
input_dim = train_trajectories.shape[2]
output_dim = input_dim
from functools import partial


def _collate(batch, data_type):
    """Collate a list of trajectories against the shared time grid ``t``.

    ``t`` and ``extrapolate`` are looked up globally at call time.
    """
    return basic_collate_fn(batch, t, data_type=data_type, extrap=extrapolate)


def _make_loader(dataset, data_type, shuffle):
    """Build the DataLoader for one split.

    A module-level function bound with ``functools.partial`` is used instead of
    a lambda because lambdas cannot be pickled. Picklability matters once the
    loader uses worker processes (``num_workers > 0`` under the spawn start
    method, when run as a script).
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=partial(_collate, data_type=data_type),
    )


dltrain = _make_loader(train_trajectories, "train", shuffle=True)
dlval = _make_loader(val_trajectories, "test", shuffle=False)
dltest = _make_loader(test_trajectories, "test", shuffle=False)
Instantiate the model and sample a prediction
[3]:
latent_dim = 2  # K: dimension of the latent initial-state tensor p
hidden_units = 64
encode_obs_time = True
s_recon_terms = 33  # passed as s_dim to LaplaceRepresentationFunc
encoder = ReverseGRUEncoder(
    input_dim,
    latent_dim,
    hidden_units // 2,
    encode_obs_time=encode_obs_time,
).to(device)
# Pass hidden_units explicitly so this setting also sizes the representation
# network (its default, 64, equals the current value).
laplace_rep_func = LaplaceRepresentationFunc(
    s_recon_terms, output_dim, latent_dim, hidden_units=hidden_units
).to(device)
We can generate predictions using laplace_reconstruct
as follows:
[4]:
from torchlaplace import laplace_reconstruct
# Put both modules into inference mode for the sample prediction.
laplace_rep_func.eval()
encoder.eval()
# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    for batch in dlval:
        trajs_to_encode = batch["observed_data"]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        # p is the latent tensor encoding the initial states
        p = encoder(trajs_to_encode, observed_tp)
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
# NOTE: `predictions` holds the reconstruction of the last batch of dlval.
[5]:
# Display the shape of the sample reconstruction computed in the previous cell.
predictions.shape
[5]:
torch.Size([100, 100, 1])
Train end-to-end
[6]:
from time import time
from copy import deepcopy

learning_rate = 1e-3
epochs = 1000
patience = None
# Compare against None explicitly: patience=0 is a legitimate setting
# ("stop at the first epoch without improvement") and must not be overridden.
if patience is None:
    patience = epochs
params = list(laplace_rep_func.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)
loss_fn = torch.nn.MSELoss()


def _batch_mse(batch):
    """Encode ``batch``, reconstruct its prediction time points and return the MSE loss tensor.

    Train, validation and test all go through this helper, so they call
    ``laplace_reconstruct`` with the same arguments (including ``recon_dim``).
    """
    trajs_to_encode = batch["observed_data"]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    # p is the latent tensor encoding the initial states
    p = encoder(trajs_to_encode, observed_tp)
    predictions = laplace_reconstruct(
        laplace_rep_func, p, batch["tp_to_predict"], recon_dim=output_dim
    )
    return loss_fn(
        torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
    )


def _evaluate_mse(dataloader):
    """Return the mean per-batch MSE over ``dataloader`` in eval mode, without gradients."""
    laplace_rep_func.eval()
    encoder.eval()
    cum_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            cum_loss += _batch_mse(batch).item()
            num_batches += 1
    return cum_loss / num_batches


best_loss = float("inf")
# Snapshot the initial weights so "Load best model" below is always defined,
# even if validation never improves on +inf (e.g. NaN losses).
best_laplace_rep_func = deepcopy(laplace_rep_func.state_dict())
best_encoder = deepcopy(encoder.state_dict())
waiting = 0
for epoch in range(epochs):
    iteration = 0
    epoch_train_loss_it_cum = 0
    start_time = time()
    laplace_rep_func.train()
    encoder.train()
    for batch in dltrain:
        optimizer.zero_grad()
        loss = _batch_mse(batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1)
        optimizer.step()
        epoch_train_loss_it_cum += loss.item()
        iteration += 1
    epoch_train_loss = epoch_train_loss_it_cum / iteration
    epoch_duration = time() - start_time  # training time only; validation excluded
    # Validation step
    val_mse = _evaluate_mse(dlval)
    if epoch % 100 == 0:
        print(
            "[epoch={}] epoch_duration={:.2f} | train_loss={}\t| val_mse={}\t|".format(
                epoch, epoch_duration, epoch_train_loss, val_mse
            )
        )
    # Early stopping: stop after `patience` consecutive epochs without improvement.
    if val_mse < best_loss:
        best_loss = val_mse
        best_laplace_rep_func = deepcopy(laplace_rep_func.state_dict())
        best_encoder = deepcopy(encoder.state_dict())
        waiting = 0
    else:
        waiting += 1
        if waiting >= patience:
            break
# Load best model
laplace_rep_func.load_state_dict(best_laplace_rep_func)
encoder.load_state_dict(best_encoder)
# Test step: same reconstruction settings as training and validation.
test_mse = _evaluate_mse(dltest)
print(f"test_mse= {test_mse}")
[epoch=0] epoch_duration=0.14 | train_loss=1.1861835888453893 | val_mse=1.0386548042297363 |
[epoch=100] epoch_duration=0.05 | train_loss=0.26202377889837536 | val_mse=0.28236162662506104 |
[epoch=200] epoch_duration=0.02 | train_loss=0.13380443517650878 | val_mse=0.13365934789180756 |
[epoch=300] epoch_duration=0.02 | train_loss=0.08324573934078217 | val_mse=0.08065777271986008 |
[epoch=400] epoch_duration=0.02 | train_loss=0.06641621195844241 | val_mse=0.06298608332872391 |
[epoch=500] epoch_duration=0.02 | train_loss=0.055839442781039646 | val_mse=0.06235235556960106 |
[epoch=600] epoch_duration=0.02 | train_loss=0.056307445147207806 | val_mse=0.07489385455846786 |
[epoch=700] epoch_duration=0.02 | train_loss=0.051500603024448664 | val_mse=0.045799192041158676 |
[epoch=800] epoch_duration=0.02 | train_loss=0.04638872029525893 | val_mse=0.04121057316660881 |
[epoch=900] epoch_duration=0.02 | train_loss=0.05361673395548548 | val_mse=0.05730640888214111 |
test_mse= 0.05796055868268013