Neural State-Space Modeling with Latent Causal-Effect Disentanglement
[MICCAI-MLMI, arXiv]
This repository holds the experiments and models explored in the paper "Neural State-Space Modeling with Latent Causal-Effect Disentanglement," written for the 13th Machine Learning in Medical Imaging (MLMI 2022) Workshop.
Despite substantial progress in deep learning approaches to time-series reconstruction, no existing methods are designed to uncover local activities with minute signal strength, because such activities contribute negligibly to the optimization loss. These local activities, however, can signify important abnormal events in physiological systems, such as an extra focus triggering an abnormal propagation of electrical waves in the heart. We discuss a novel technique for reconstructing such local activity that, while small in signal strength, is the cause of subsequent global activities that have larger signal strength. Our central innovation is to explicitly model and disentangle how the latent state of a system is influenced by potential hidden internal interventions. In a novel neural formulation of state-space models (SSMs), we first introduce causal-effect modeling of the latent dynamics via a system of interacting neural ODEs that separately describes 1) the continuous-time dynamics of the internal intervention, and 2) its effect on the trajectory of the system's native state. Because the intervention cannot be directly observed but has to be disentangled from its observed subsequent effect, we integrate knowledge of the native, intervention-free dynamics of the system and infer the hidden intervention by assuming it is responsible for the differences observed between the actual and the hypothetical intervention-free dynamics. We demonstrate a proof of concept of the presented framework on reconstructing ectopic foci that disrupt the course of normal cardiac electrical propagation from remote observations.
Fig 1. Schematic of the proposed ODE-VAE-IM model.
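The inference idea in the abstract rests on one signal: how far the actual trajectory drifts from a hypothetical intervention-free rollout started from the same state. The sketch below computes that per-step discrepancy under stated assumptions. It uses plain Python callables in place of the trained native-dynamics network, forward Euler in place of an ODE solver, and treats the "observed" states as latent vectors directly. `euler_rollout` and `discrepancy_profile` are hypothetical names, not functions from this repository.

```python
import math
from typing import Callable, List

Vector = List[float]


def euler_rollout(f: Callable[[Vector], Vector], z0: Vector, dt: float, steps: int) -> List[Vector]:
    """Intervention-free trajectory: integrate dz/dt = f(z) from z0 with forward Euler."""
    traj = [list(z0)]
    for _ in range(steps):
        z = traj[-1]
        dz = f(z)
        traj.append([zi + dt * dzi for zi, dzi in zip(z, dz)])
    return traj


def discrepancy_profile(observed: List[Vector], f_native: Callable[[Vector], Vector], dt: float) -> List[float]:
    """Per-step L2 distance between observed states and the intervention-free rollout.

    The rollout starts from observed[0], so any non-zero distance at later steps is
    what a hidden intervention would have to explain.
    """
    expected = euler_rollout(f_native, observed[0], dt, len(observed) - 1)
    return [math.dist(o, e) for o, e in zip(observed, expected)]
```

If the observed trajectory follows the native dynamics exactly, the profile is all zeros. A perturbation injected at step t shows up as a non-zero distance from step t onward.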
If you use portions of this repository or have found use for the model in your research directions, please consider citing:
@misc{toloubidokhti2022neural,
title={Neural State-Space Modeling with Latent Causal-Effect Disentanglement},
author={Maryam Toloubidokhti and Ryan Missel and Xiajun Jiang and Niels Otani and Linwei Wang},
year={2022},
eprint={2209.12387},
archivePrefix={arXiv}
}
- A requirements.txt file lists all packages used within the repository; it was generated automatically with the Python package pigar.
- Trained checkpoints for each model are provided in the experiments folder.
- Pre-generated datasets used in the test reconstruction visualizations are provided in the vals/ folder for each model.
- Code showing how to build and train neural state-space models (and ODE-VAEs) is provided in the Model Code folder.
We provide data generation scripts in the Data Generation folder. To use the data in this codebase, move the generated Intervention and Normal files to the Pacing and Normal folders, respectively. Then set the function parameter newload to True in data_loader.py and run the code to generate each dataset's files.
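The file-moving step can be scripted. The sketch below assumes the MATLAB output folders (Intervention, Normal) sit under one directory and the codebase's Pacing and Normal folders sit under another, and it moves the contents (for example the TMP and BSP subfolders) rather than the folders themselves. Both root paths and the function name `place_generated_data` are hypothetical; check data_loader.py for the layout it actually expects. The newload toggle is not shown because its exact signature is not documented here.

```python
import shutil
from pathlib import Path
from typing import Union

# Generated MATLAB folder -> codebase folder, as described above.
FOLDER_MAP = {"Intervention": "Pacing", "Normal": "Normal"}


def place_generated_data(generated_root: Union[str, Path], data_root: Union[str, Path]) -> None:
    """Move the contents of each generated folder into its codebase counterpart.

    Assumes generated_root and data_root are different directories.
    """
    generated_root, data_root = Path(generated_root), Path(data_root)
    for src_name, dst_name in FOLDER_MAP.items():
        src = generated_root / src_name
        dst = data_root / dst_name
        dst.mkdir(parents=True, exist_ok=True)
        for item in src.iterdir():
            target = dst / item.name
            # shutil.move would nest into an existing directory, so refuse instead.
            if target.exists():
                raise FileExistsError(f"{target} already exists; refusing to overwrite")
            shutil.move(str(item), str(target))
```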
The Normal dataset refers to the base dynamics presented in the paper: native transmembrane potential (TMP) sets containing 1000 voltage maps, in which the initial excitation locations are chosen randomly across a 100 × 100 mm 2D grid. The Intervention dataset refers to the transmembrane potentials when an extra focus is present; it contains 705 samples with varying initial locations and times for both the initial excitation and the extra focus.
Generating the Normal dynamics dataset:
- (Skip this step if you wish to use the provided forward matrix H.mat.) You can use the generate_H_3d.m code to generate the forward matrix with the desired parameters (details are commented in the code). The H matrix used to generate the dataset for this paper is provided as H.mat. H is loaded in line 20 of generate_H_3d.m.
- Use MATLAB to run the normal.m code. The first 1000 iterations generate TMP-BSP pairs assuming there is only one pacing location, at the first time step. The next 1000 iterations place the extra focus at a random location in the first time step.
- Data will be saved in the 'Normal' folder: TMPs in the 'TMP' folder and the corresponding body-surface potentials (BSPs) in the 'BSP' folder.
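Each TMP-BSP pair is linked by the forward matrix H. The sketch below assumes the standard linear forward model of ECG imaging, in which every flattened TMP frame x_t is mapped to a BSP frame y_t = H x_t. It is not a port of normal.m, and `forward_project` is a hypothetical name. A real H would come from H.mat (in Python, scipy.io.loadmat can read it, though the variable name stored inside the file is not documented here).

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def forward_project(H: Matrix, tmp_frames: List[Vector]) -> List[Vector]:
    """Map each flattened TMP frame x_t to a BSP frame y_t = H @ x_t."""
    n_nodes = len(H[0])
    bsp_frames = []
    for x in tmp_frames:
        if len(x) != n_nodes:
            raise ValueError(f"TMP frame has {len(x)} nodes, H expects {n_nodes}")
        # One matrix-vector product per time step.
        bsp_frames.append([sum(h * xi for h, xi in zip(row, x)) for row in H])
    return bsp_frames
```

For example, with H = [[1, 0, 1], [0, 2, 0]] and TMP frames [[1, 2, 3], [0, 1, 0]], the BSP frames are [[4, 4], [0, 2]].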
Generating the Intervention dataset (extra foci):
- Use the same H as for the Normal dynamics dataset. H is loaded in line 20 of extra_pacing.m.
- Use MATLAB to run the extra_pacing.m code. After each sample is generated, you are asked whether to save it; enter 1 to save the sample. This option is used to discard samples where the extra focus does not occur and the sample looks exactly like the normal dynamics data.
- Data will be saved in the 'Intervention' folder: TMPs in the 'TMP' folder and the corresponding BSPs in the 'BSP' folder.
We provide 3 models in the Model Code folder, using PyTorch Lightning modules to handle the train/test steps. These include:
- ODE-VAE: the base dynamics model, which uses a CNN spatial encoder/decoder to map observations to a latent state z_i and a neural ODE dynamics network to perform the forecasting. Given k initial frames of observation, it forecasts all the way out to T timesteps.
- ODE-VAE-GRU: the baseline we compare against. It uses the same ODE dynamics network but additionally uses a sliding window of forward observations to update the latent state with a GRU cell at every timestep. The sliding window is fed into the same initial latent state encoder used to initialize z_0.
- ODE-VAE-IM: the proposed intervention model, which consists of a system of interacting neural ODEs that separately describe 1) the latent dynamics of the internal intervention and 2) its effect on the resulting trajectory of the system's native dynamics. It leverages an ODE-VAE pre-trained on the native dynamics and learns another neural ODE F(a) to describe the intervention dynamics. The post-intervention states are predicted through a coupled ODE system F(z + a), in which the intervention dynamics directly affect the ODE vector predictions at each integration step.
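The coupling in ODE-VAE-IM can be read as two equations integrated together: dz/dt = F(z + a) for the native state and da/dt = F_a(a) for the intervention. The sketch below shows only that structure. It assumes plain Python callables instead of the pre-trained ODE network and the learned intervention network, and forward Euler instead of whatever solver the repository uses; `integrate_coupled` is a hypothetical name.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def integrate_coupled(
    f_native: Callable[[Vector], Vector],
    f_interv: Callable[[Vector], Vector],
    z0: Vector,
    a0: Vector,
    dt: float,
    steps: int,
) -> Tuple[List[Vector], List[Vector]]:
    """Forward-Euler rollout of dz/dt = f_native(z + a), da/dt = f_interv(a)."""
    z, a = list(z0), list(a0)
    z_traj, a_traj = [z], [a]
    for _ in range(steps):
        # The native vector field is evaluated at the intervened state z + a.
        dz = f_native([zi + ai for zi, ai in zip(z, a)])
        # The intervention evolves under its own dynamics.
        da = f_interv(a)
        z = [zi + dt * dzi for zi, dzi in zip(z, dz)]
        a = [ai + dt * dai for ai, dai in zip(a, da)]
        z_traj.append(z)
        a_traj.append(a)
    return z_traj, a_traj
```

With a decaying native field f(x) = -x, dt = 0.1 and z0 = [1.0], one step gives z close to 0.9 when a0 = [0.0] but close to 0.8 when a0 = [1.0]: the intervention changes the native trajectory only through the argument of F, which is the design choice that lets the native network be reused unchanged.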
One can either:
- Manually change the hyperparameters in the script and simply run python3 train.py
- Or specify the arguments on the command line, e.g. python3 train.py --model ode_vae --version normal --batch_size 32
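The two options above follow the usual pattern of defaults set in the script that command-line flags can override. The sketch below shows that pattern with Python's argparse for the three flags in the example command only. It is not the parser in train.py: the real script likely accepts more arguments, and the defaults here are placeholders copied from the example command.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags in the example command."""
    parser = argparse.ArgumentParser(description="Sketch of a train.py-style CLI")
    # Placeholder defaults; editing these corresponds to "manually changing the hyperparameters".
    parser.add_argument("--model", type=str, default="ode_vae")
    parser.add_argument("--version", type=str, default="normal")
    parser.add_argument("--batch_size", type=int, default=32)
    return parser
```

Calling build_parser().parse_args() with no arguments returns the defaults, and any flag given on the command line replaces its default.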
Fig 2. Reconstruction of electrical propagation in which ectopic foci occur.
Fig 3. Reconstruction of electrical propagation in which ectopic foci occur.
Fig 4. Visualizations of the L2-Norm of system and intervention states over time.