{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# User Guide Laplace Reconstruct\n", "\n", "\n", "## Basic usage\n", "\n", "This library provides one main interface `laplace_reconstruct` which uses a selected inverse Laplace transform algorithm to reconstruct trajectories from a provided parameterized Laplace representation functional $\\mathbf{F}(\\mathbf{p},\\mathbf{s})$,\n", "\n", "$$\\mathbf{x}(t) = \\text{inverse laplace transform}(\\mathbf{F}(\\mathbf{p},\\mathbf{s}), t)$$\n", "\n", "Where $\\mathbf{p}$ is a Tensor encoding the initial system state as a latent variable, and $t$ is the time points to reconstruct trajectories for.\n", "\n", "This can be used by" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torchlaplace import laplace_reconstruct\n", "\n", "laplace_reconstruct(laplace_rep_func, p, t)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "where `laplace_rep_func` is any callable implementing the parameterized Laplace representation functional $\\mathbf{F}(\\mathbf{p},\\mathbf{s})$, `p` is a Tensor encoding the initial state of shape $(\\text{MiniBatchSize},\\text{K})$.\n", "Where $\\text{K}$ is a hyperparameter, and can be set by the user.\n", "Finally, `t` is a Tensor of shape $(\\text{MiniBatchSize},\\text{SeqLen})$\n", "or $(\\text{SeqLen})$ containing the time points to reconstruct the trajectories for.\n", "\n", "Note that this is not numerically stable for all ILT methods, however should probably be fine with the default `fourier` (fourier series inverse) ILT algorithm.\n", "\n", "The parameterized Laplace representation functional `laplace_rep_func`, $\\mathbf{F}(\\mathbf{p},\\mathbf{s})$\n", "also takes an input complex value $\\mathbf{s}$.\n", "This $\\mathbf{s}$ is used internally when reconstructing a specified time point with the selected inverse Laplace transform algorithm `ilt_algorithm`.\n", "\n", "The biggest **gotcha** is that `laplace_rep_func` must be a `nn.Module` when using the `laplace_rep_func` function. 
"\n",
"To replicate the experiments in [[1]](https://arxiv.org/abs/2206.04843), see the [experiments](https://github.com/samholt/NeuralLaplace/tree/master/experiments) directory.\n",
"\n",
"## Example\n",
"\n",
"\n",
"### Define an encoder and a Laplace representation functional F(p, s)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from torch import nn\n",
"import torch\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Model (encoder and Laplace representation func)\n",
"class ReverseGRUEncoder(nn.Module):\n",
"    # Encodes observed trajectory into latent vector\n",
"    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):\n",
"        super(ReverseGRUEncoder, self).__init__()\n",
"        self.encode_obs_time = encode_obs_time\n",
"        if self.encode_obs_time:\n",
"            dimension_in += 1\n",
"        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)\n",
"        self.linear_out = nn.Linear(hidden_units, latent_dim)\n",
"        nn.init.xavier_uniform_(self.linear_out.weight)\n",
"\n",
"    def forward(self, observed_data, observed_tp):\n",
"        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)\n",
"        if self.encode_obs_time:\n",
"            trajs_to_encode = torch.cat(\n",
"                (\n",
"                    observed_data,\n",
"                    observed_tp.view(1, -1, 1).repeat(observed_data.shape[0], 1, 1),\n",
"                ),\n",
"                dim=2,\n",
"            )\n",
"        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))\n",
"        out, _ = self.gru(reversed_trajs_to_encode)\n",
"        return self.linear_out(out[:, -1, :])\n",
"\n",
"class LaplaceRepresentationFunc(nn.Module):\n",
"    # SphereSurfaceModel: C^{b+k} -> C^{b x d} in Riemann sphere coordinates;\n",
"    # b is the number of s reconstruction terms, k is the latent encoding\n",
"    # dimension, and d is the output dimension\n",
"    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):\n",
"        super(LaplaceRepresentationFunc, self).__init__()\n",
"        self.s_dim = s_dim\n",
"        self.output_dim = output_dim\n",
"        self.latent_dim = latent_dim\n",
"        self.linear_tanh_stack = nn.Sequential(\n",
"            nn.Linear(s_dim * 2 + latent_dim, hidden_units),\n",
"            nn.Tanh(),\n",
"            nn.Linear(hidden_units, hidden_units),\n",
"            nn.Tanh(),\n",
"            nn.Linear(hidden_units, s_dim * 2 * output_dim),\n",
"        )\n",
"\n",
"        for m in self.linear_tanh_stack.modules():\n",
"            if isinstance(m, nn.Linear):\n",
"                nn.init.xavier_uniform_(m.weight)\n",
"\n",
"        phi_max = torch.pi / 2.0\n",
"        self.phi_scale = phi_max - (-torch.pi / 2.0)\n",
"\n",
"    def forward(self, i):\n",
"        out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(\n",
"            -1, 2 * self.output_dim, self.s_dim\n",
"        )\n",
"        theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From -pi to +pi\n",
"        phi = (\n",
"            nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0\n",
"            - torch.pi / 2.0\n",
"            + self.phi_scale / 2.0\n",
"        )  # From -pi / 2 to +pi / 2\n",
"        return theta, phi\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load a dataset\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n",
"from torchlaplace.data_utils import basic_collate_fn\n",
"\n",
"normalize_dataset = True\n",
"batch_size = 128\n",
"extrapolate = True\n",
"\n",
"def sawtooth(trajectories_to_sample=100, t_nsamples=200):\n",
"    # Toy sawtooth waveform. Simple to generate; for differential equation\n",
"    # datasets see datasets.py (note that more complex DEs can take time to\n",
"    # sample from, in some cases minutes).\n",
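"    # The sampler below returns the fractional part of (t + x0) / (2 * pi),\n",
"    # i.e. a sawtooth ramp from 0 to 1 with period 2 * pi and phase offset x0.\n",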
"    t_end = 20.0\n",
"    t_begin = t_end / t_nsamples\n",
"    ti = torch.linspace(t_begin, t_end, t_nsamples).to(device)\n",
"\n",
"    def sampler(t, x0=0):\n",
"        return (t + x0) / (2 * torch.pi) - torch.floor((t + x0) / (2 * torch.pi))\n",
"\n",
"    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample)\n",
"    trajs = []\n",
"    for x0 in x0s:\n",
"        trajs.append(sampler(ti, x0))\n",
"    y = torch.stack(trajs)\n",
"    trajectories = y.view(trajectories_to_sample, -1, 1)\n",
"    return trajectories, ti\n",
"\n",
"\n",
"trajectories, t = sawtooth(\n",
"    trajectories_to_sample=1000,\n",
"    t_nsamples=200,\n",
")\n",
"if normalize_dataset:\n",
"    samples = trajectories.shape[0]\n",
"    dim = trajectories.shape[2]\n",
"    traj = (\n",
"        torch.reshape(trajectories, (-1, dim))\n",
"        - torch.reshape(trajectories, (-1, dim)).mean(0)\n",
"    ) / torch.reshape(trajectories, (-1, dim)).std(0)\n",
"    trajectories = torch.reshape(traj, (samples, -1, dim))\n",
"train_split = int(0.8 * trajectories.shape[0])\n",
"test_split = int(0.9 * trajectories.shape[0])\n",
"traj_index = torch.randperm(trajectories.shape[0])\n",
"train_trajectories = trajectories[traj_index[:train_split], :, :]\n",
"val_trajectories = trajectories[traj_index[train_split:test_split], :, :]\n",
"test_trajectories = trajectories[traj_index[test_split:], :, :]\n",
"\n",
"input_dim = train_trajectories.shape[2]\n",
"output_dim = input_dim\n",
"dltrain = DataLoader(\n",
"    train_trajectories,\n",
"    batch_size=batch_size,\n",
"    shuffle=True,\n",
"    collate_fn=lambda batch: basic_collate_fn(\n",
"        batch,\n",
"        t,\n",
"        data_type=\"train\",\n",
"        extrap=extrapolate,\n",
"    ),\n",
")\n",
"dlval = DataLoader(\n",
"    val_trajectories,\n",
"    batch_size=batch_size,\n",
"    shuffle=False,\n",
"    collate_fn=lambda batch: basic_collate_fn(\n",
"        batch,\n",
"        t,\n",
"        data_type=\"test\",\n",
"        extrap=extrapolate,\n",
"    ),\n",
")\n",
"dltest = DataLoader(\n",
"    test_trajectories,\n",
"    batch_size=batch_size,\n",
"    shuffle=False,\n",
"    collate_fn=lambda batch: basic_collate_fn(\n",
"        batch,\n",
"        t,\n",
"        data_type=\"test\",\n",
"        extrap=extrapolate,\n",
"    ),\n",
")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Instantiate model and sample prediction" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "latent_dim = 2\n",
"hidden_units = 64\n",
"encode_obs_time = True\n",
"s_recon_terms = 33\n",
"\n",
"encoder = ReverseGRUEncoder(\n",
"    input_dim,\n",
"    latent_dim,\n",
"    hidden_units // 2,\n",
"    encode_obs_time=encode_obs_time,\n",
").to(device)\n",
"laplace_rep_func = LaplaceRepresentationFunc(\n",
"    s_recon_terms, output_dim, latent_dim\n",
").to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can generate predictions using `laplace_reconstruct` as follows:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from torchlaplace import laplace_reconstruct\n",
"\n",
"laplace_rep_func.eval(), encoder.eval()\n",
"for batch in dlval:\n",
"    trajs_to_encode = batch[\n",
"        \"observed_data\"\n",
"    ]  # (batch_size, t_observed_dim, observed_dim)\n",
"    observed_tp = batch[\"observed_tp\"]  # (1, t_observed_dim)\n",
"    p = encoder(\n",
"        trajs_to_encode, observed_tp\n",
"    )  # p is the latent tensor encoding the initial states\n",
"    tp_to_predict = batch[\"tp_to_predict\"]\n",
"    # recon_dim gives the output dimension d of each reconstructed trajectory\n",
"    predictions = laplace_reconstruct(\n",
"
laplace_rep_func, p, tp_to_predict, recon_dim=output_dim\n", " )" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([100, 100, 1])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train end to end " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[epoch=0] epoch_duration=0.14 | train_loss=1.1861835888453893\t| val_mse=1.0386548042297363\t|\n", "[epoch=100] epoch_duration=0.05 | train_loss=0.26202377889837536\t| val_mse=0.28236162662506104\t|\n", "[epoch=200] epoch_duration=0.02 | train_loss=0.13380443517650878\t| val_mse=0.13365934789180756\t|\n", "[epoch=300] epoch_duration=0.02 | train_loss=0.08324573934078217\t| val_mse=0.08065777271986008\t|\n", "[epoch=400] epoch_duration=0.02 | train_loss=0.06641621195844241\t| val_mse=0.06298608332872391\t|\n", "[epoch=500] epoch_duration=0.02 | train_loss=0.055839442781039646\t| val_mse=0.06235235556960106\t|\n", "[epoch=600] epoch_duration=0.02 | train_loss=0.056307445147207806\t| val_mse=0.07489385455846786\t|\n", "[epoch=700] epoch_duration=0.02 | train_loss=0.051500603024448664\t| val_mse=0.045799192041158676\t|\n", "[epoch=800] epoch_duration=0.02 | train_loss=0.04638872029525893\t| val_mse=0.04121057316660881\t|\n", "[epoch=900] epoch_duration=0.02 | train_loss=0.05361673395548548\t| val_mse=0.05730640888214111\t|\n", "test_mse= 0.05796055868268013\n" ] } ], "source": [ "from time import time\n", "from copy import deepcopy\n", "\n", "learning_rate = 1e-3\n", "epochs = 1000\n", "patience = None\n", "\n", "if not patience:\n", " patience = epochs\n", "\n", "params = list(laplace_rep_func.parameters()) + list(encoder.parameters())\n", "optimizer = torch.optim.Adam(params, lr=learning_rate)\n", "loss_fn = torch.nn.MSELoss()\n", "\n", "best_loss = float(\"inf\")\n", "waiting = 0\n", "\n", "for epoch in range(epochs):\n", " iteration = 0\n", " epoch_train_loss_it_cum = 0\n", " start_time = time()\n", " laplace_rep_func.train(), encoder.train()\n", " for batch in dltrain:\n", " optimizer.zero_grad()\n", " trajs_to_encode = batch[\n", " \"observed_data\"\n", " ] # (batch_size, t_observed_dim, observed_dim)\n", " observed_tp = batch[\"observed_tp\"] # (1, t_observed_dim)\n", " p = encoder(\n", " trajs_to_encode, observed_tp\n", " ) # p is the latent tensor encoding the initial states\n", " tp_to_predict = batch[\"tp_to_predict\"]\n", " predictions = laplace_reconstruct(\n", " laplace_rep_func, p, tp_to_predict, recon_dim=output_dim\n", " )\n", " loss = loss_fn(\n", " torch.flatten(predictions), torch.flatten(batch[\"data_to_predict\"])\n", " )\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(params, 1)\n", " optimizer.step()\n", " epoch_train_loss_it_cum += loss.item()\n", " iteration += 1\n", " epoch_train_loss = epoch_train_loss_it_cum / iteration\n", " epoch_duration = time() - start_time\n", "\n", " # Validation step\n", " laplace_rep_func.eval(), encoder.eval()\n", " cum_val_loss = 0\n", " cum_val_batches = 0\n", " for batch in dlval:\n", " trajs_to_encode = batch[\n", " \"observed_data\"\n", " ] # (batch_size, t_observed_dim, observed_dim)\n", " observed_tp = batch[\"observed_tp\"] # (1, t_observed_dim)\n", " p = encoder(\n", " trajs_to_encode, observed_tp\n", " ) # p is the latent tensor encoding the initial states\n", " tp_to_predict = 
batch[\"tp_to_predict\"]\n", " predictions = laplace_reconstruct(\n", " laplace_rep_func, p, tp_to_predict, recon_dim=output_dim\n", " )\n", " cum_val_loss += loss_fn(\n", " torch.flatten(predictions), torch.flatten(batch[\"data_to_predict\"])\n", " ).item()\n", " cum_val_batches += 1\n", " val_mse = cum_val_loss / cum_val_batches\n", " if epoch % 100 == 0:\n", " print(\n", " \"[epoch={}] epoch_duration={:.2f} | train_loss={}\\t| val_mse={}\\t|\".format(\n", " epoch, epoch_duration, epoch_train_loss, val_mse\n", " )\n", " )\n", "\n", " # Early stopping procedure\n", " if val_mse < best_loss:\n", " best_loss = val_mse\n", " best_laplace_rep_func = deepcopy(laplace_rep_func.state_dict())\n", " best_encoder = deepcopy(encoder.state_dict())\n", " waiting = 0\n", " elif waiting > patience:\n", " break\n", " else:\n", " waiting += 1\n", "\n", "# Load best model\n", "laplace_rep_func.load_state_dict(best_laplace_rep_func)\n", "encoder.load_state_dict(best_encoder)\n", "\n", "# Test step\n", "laplace_rep_func.eval(), encoder.eval()\n", "cum_test_loss = 0\n", "cum_test_batches = 0\n", "for batch in dltest:\n", " trajs_to_encode = batch[\n", " \"observed_data\"\n", " ] # (batch_size, t_observed_dim, observed_dim)\n", " observed_tp = batch[\"observed_tp\"] # (1, t_observed_dim)\n", " p = encoder(\n", " trajs_to_encode, observed_tp\n", " ) # p is the latent tensor encoding the initial states\n", " tp_to_predict = batch[\"tp_to_predict\"]\n", " predictions = laplace_reconstruct(laplace_rep_func, p, tp_to_predict)\n", " cum_test_loss += loss_fn(\n", " torch.flatten(predictions), torch.flatten(batch[\"data_to_predict\"])\n", " ).item()\n", " cum_test_batches += 1\n", "test_mse = cum_test_loss / cum_test_batches\n", "print(f\"test_mse= {test_mse}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.7 ('nl3')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "17023f5485d9c7f86e59b9fa31080efc0b1a5cffb20b1f365d13ed3a535852e7" } } }, "nbformat": 4, "nbformat_minor": 2 }