Repository for the 2022 MLMI work, "Neural State-Space Modeling with Latent Causal-Effect Disentanglement."

maryamTolou/causal-effect-neural-ssm

Neural State-Space Modeling with Latent Causal-Effect Disentanglement
[MICCAI-MLMI, arXiv]

This repository holds the experiments and models from the paper "Neural State-Space Modeling with Latent Causal-Effect Disentanglement," presented at the 13th Machine Learning in Medical Imaging (MLMI 2022) Workshop.

Overview

Despite substantial progress in deep learning approaches to time-series reconstruction, no existing methods are designed to uncover local activities with minute signal strength, because such activities contribute negligibly to the optimization loss. These local activities, however, can signify important abnormal events in physiological systems, such as an extra focus triggering an abnormal propagation of electrical waves in the heart. We present a novel technique for reconstructing such local activity that, while small in signal strength, causes subsequent global activities with larger signal strength. Our central innovation is to explicitly model and disentangle how the latent state of a system is influenced by potential hidden internal interventions. In a novel neural formulation of state-space models (SSMs), we first introduce causal-effect modeling of the latent dynamics via a system of interacting neural ODEs that separately describe 1) the continuous-time dynamics of the internal intervention and 2) its effect on the trajectory of the system's native state. Because the intervention cannot be observed directly but must be disentangled from its observed subsequent effect, we integrate knowledge of the system's native, intervention-free dynamics and infer the hidden intervention by assuming it is responsible for the differences observed between the actual and hypothetical intervention-free dynamics. We demonstrate a proof of concept of the framework by reconstructing ectopic foci that disrupt the course of normal cardiac electrical propagation from remote observations.
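As a hedged illustration only, the interacting system can be written as below. The symbols are our own shorthand rather than the paper's notation: z(t) is the native latent state, a(t) the hidden intervention state, F_z the native dynamics network trained on intervention-free data, F_a the intervention network, and z̃(t) the hypothetical intervention-free trajectory. The coupling through z + a follows the F(z + a) description in the Models section; consult the paper for the exact parameterization.

```latex
\begin{aligned}
\frac{d\mathbf{a}}{dt} &= F_a\big(\mathbf{a}(t)\big), \\
\frac{d\mathbf{z}}{dt} &= F_z\big(\mathbf{z}(t) + \mathbf{a}(t)\big), \\
\frac{d\tilde{\mathbf{z}}}{dt} &= F_z\big(\tilde{\mathbf{z}}(t)\big), \qquad \tilde{\mathbf{z}}(0) = \mathbf{z}(0).
\end{aligned}
```

When a(t) stays at zero, z(t) and z̃(t) coincide, so any divergence between the observed trajectory and the intervention-free one has to be explained by a(t).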

Fig 1. Schematic of the proposed ODE-VAE-IM model.

Citation

If you use portions of this repository or find the model useful in your research, please consider citing:

@misc{toloubidokhti2022neural,
      title={Neural State-Space Modeling with Latent Causal-Effect Disentanglement}, 
      author={Maryam Toloubidokhti and Ryan Missel and Xiajun Jiang and Niels Otani and Linwei Wang},
      year={2022},
      eprint={2209.12387},
      archivePrefix={arXiv}
}

Setup

  • A requirements.txt file, automatically generated with the Python package pigar, lists all packages used in the repository.
  • Trained checkpoints for each model are provided in the experiments folder.
  • Pre-generated datasets used for the test-time reconstruction visualizations are provided in each model's vals/ folder.
  • Code for building and training neural state-space models (and ODE-VAEs) is provided in the Model Code folder.

Data

We provide data generation scripts in the Data Generation folder. To use the generated data in this codebase, move the generated Intervention and Normal files into the Pacing and Normal folders, respectively. Then set the function parameter newload to True in data_loader.py and run it to generate each dataset's files.
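The folder staging step can be scripted. The sketch below is hypothetical: the function name, the FOLDER_MAP dictionary, and the assumption that each generated folder holds TMP/ and BSP/ subfolders whose contents go directly into Pacing/ and Normal/ are ours, so adjust the paths to where data_loader.py actually looks in your checkout.

```python
import shutil
from pathlib import Path

# Hypothetical mapping: generated folder name -> folder expected by data_loader.py
FOLDER_MAP = {"Intervention": "Pacing", "Normal": "Normal"}


def stage_generated_data(gen_root, data_root):
    """Move the contents of the MATLAB output folders into the loader's folders.

    gen_root: directory holding the generated 'Intervention' and 'Normal' folders.
    data_root: directory where the loader expects 'Pacing' and 'Normal' (assumed).
    Returns the list of destination paths that were created.
    """
    gen_root, data_root = Path(gen_root), Path(data_root)
    moved = []
    for src_name, dst_name in FOLDER_MAP.items():
        src = gen_root / src_name
        dst = data_root / dst_name
        dst.mkdir(parents=True, exist_ok=True)
        for item in src.iterdir():  # e.g. the TMP/ and BSP/ subfolders
            target = dst / item.name
            if target.exists():
                # Refuse to overwrite or nest into an existing folder
                raise FileExistsError(target)
            shutil.move(str(item), str(target))
            moved.append(target)
    return moved
```

Raising on an existing destination is deliberate: shutil.move would otherwise nest a folder inside an existing one of the same name, which is easy to miss.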

The Normal dataset contains the base dynamics presented in the paper: a set of 1000 native transmembrane potential (TMP) voltage maps, in which the initial excitation location is chosen randomly across the 100 × 100 mm 2D grid. The Intervention dataset contains the transmembrane potentials when an extra focus is present: 705 samples with varying initial locations and times for both the excitation and the extra focus.

Generating the Normal dynamics dataset:

  1. (Skip this step if you wish to use the provided forward matrix H.mat.) Use generate_H_3d.m to generate the forward matrix with the desired parameters (details are commented in the code). The H matrix used to generate the dataset for this paper is provided as H.mat. H is loaded on line 20 of generate_H_3d.m.
  2. Run normal.m in MATLAB. The first 1000 iterations generate the TMP-BSP (transmembrane potential / body-surface potential) pairs, assuming there is only one pacing location at the first time step. The next 1000 iterations place the extra focus at a random location in the first time step.
  3. Data will be saved in the 'Normal' folder, with TMPs in the 'TMP' subfolder and the corresponding BSPs in the 'BSP' subfolder.
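The TMP-BSP pairs come from a linear forward model through H. A minimal numpy sketch of that projection follows, assuming H has shape (n_electrodes, n_nodes) and a TMP sequence is stored as (n_nodes, T) with one column per time step; the MATLAB scripts remain the reference implementation. The variable key under which H is stored in H.mat is not documented here, so the loading snippet uses a hypothetical key name.

```python
import numpy as np


def forward_project(H, tmp):
    """Map a TMP sequence to body-surface potentials with a linear forward matrix.

    H: (n_electrodes, n_nodes) forward matrix.
    tmp: (n_nodes, T) transmembrane potentials, one column per time step.
    Returns an (n_electrodes, T) array of BSPs.
    """
    H = np.asarray(H, dtype=float)
    tmp = np.asarray(tmp, dtype=float)
    if H.shape[1] != tmp.shape[0]:
        raise ValueError(
            f"H has {H.shape[1]} columns but the TMP sequence has {tmp.shape[0]} nodes"
        )
    # Each time step is projected independently: bsp[:, t] = H @ tmp[:, t]
    return H @ tmp


# Loading the provided matrix (the key "H" is an assumption; inspect the file first):
# from scipy.io import loadmat
# H = loadmat("H.mat")["H"]
```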

Generating the Intervention dataset (extra Foci):

  1. Use the same H as for the Normal dynamics dataset. H is loaded on line 20 of extra_pacing.m.
  2. Run extra_pacing.m in MATLAB. After each sample is generated, you are asked whether to save it; enter 1 to save the sample. This option is used to discard samples where the extra focus does not occur and the sample looks exactly like the normal dynamics data.
  3. Data will be saved in the 'Intervention' folder, with TMPs in the 'TMP' subfolder and the corresponding BSPs in the 'BSP' subfolder.

Models

We provide three models in the Model Code folder, using PyTorch Lightning modules to handle the train/test steps:

  • ODE-VAE: The base dynamics model, which uses a CNN spatial encoder/decoder to map observations to and from a latent state z_i and a neural ODE dynamics network to perform the forecasting. Given k initial observed frames, it forecasts the full sequence out to T time steps.
  • ODE-VAE-GRU: The baseline we compare against. It uses the same ODE dynamics network but additionally feeds a sliding window of forward observations into a GRU cell that updates the latent state at every time step. The sliding window is passed through the same initial latent state encoder used to initialize z_0.
  • ODE-VAE-IM: The proposed intervention model, which consists of a system of interacting neural ODEs that separately describe 1) the latent dynamics of the internal intervention and 2) its effect on the resulting trajectory of the system's native dynamics. It leverages an ODE-VAE pre-trained on the native dynamics and learns another neural ODE F(a) to describe the intervention dynamics. The prediction of post-intervention states is handled through a coupled ODE system F(z + a), in which the intervention dynamics directly affect the ODE vector predictions at each integration step.
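The F(z + a) coupling in ODE-VAE-IM can be illustrated with a small rollout. This is a sketch under stated assumptions, not the repository's implementation: the vector fields are plain Python callables standing in for the neural networks, forward Euler stands in for whatever ODE solver the code uses, and the function name is hypothetical.

```python
import numpy as np


def integrate_intervention_model(F_z, F_a, z0, a0, dt, n_steps):
    """Forward-Euler rollout of a coupled native/intervention latent system.

    F_z: native dynamics (pre-trained on intervention-free data), evaluated at z + a.
    F_a: intervention dynamics, evaluated at a alone.
    Returns two arrays of shape (n_steps + 1, dim): the z and a trajectories.
    """
    z = np.asarray(z0, dtype=float)
    a = np.asarray(a0, dtype=float)
    traj_z, traj_a = [z.copy()], [a.copy()]
    for _ in range(n_steps):
        dz = F_z(z + a)  # the intervention perturbs the input of the native vector field
        da = F_a(a)      # the intervention evolves under its own dynamics
        # Both derivatives are computed before either state is updated
        z = z + dt * dz
        a = a + dt * da
        traj_z.append(z.copy())
        traj_a.append(a.copy())
    return np.stack(traj_z), np.stack(traj_a)
```

With a0 = 0 and an F_a that maps zero to zero, a stays at zero and z reproduces the intervention-free rollout exactly, which is what allows the pre-trained native model to be reused unchanged; a nonzero intervention state bends the trajectory away from it.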

Running the script

One can either:

  • Manually change the hyperparameters in the script and simply run python3 train.py
  • Or specify the arguments in the command line, e.g. python3 train.py --model ode_vae --version normal --batch_size 32
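Command-line arguments override the values set in the script. The hypothetical parser below mirrors only the three flags shown in the example above; the defaults and help strings are our assumptions, and train.py may define additional arguments.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags in the README example."""
    parser = argparse.ArgumentParser(description="Train a neural state-space model")
    parser.add_argument("--model", default="ode_vae",
                        help="which model to train, e.g. ode_vae")
    parser.add_argument("--version", default="normal",
                        help="experiment/dataset version, e.g. normal")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="mini-batch size")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args.model, args.version, args.batch_size)
```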

Intervention reconstruction examples

Fig 2. Reconstruction of electrical propagation in which an ectopic focus occurs.

Fig 3. Reconstruction of electrical propagation in which an ectopic focus occurs.

Latent norm ablations

Fig 4. Visualizations of the L2-Norm of system and intervention states over time.
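Curves like those in Fig 4 come from taking the L2 norm of each latent state at every time step. A minimal sketch, assuming the system state z and the intervention state a are available as arrays of shape (T, latent_dim), optionally with a leading batch axis; the function name is hypothetical.

```python
import numpy as np


def latent_l2_norms(states):
    """Per-time-step L2 norm of a latent trajectory.

    states: array of shape (T, latent_dim) or (batch, T, latent_dim).
    The norm is always taken over the last (latent) axis, so the result
    has shape (T,) or (batch, T).
    """
    return np.linalg.norm(np.asarray(states, dtype=float), axis=-1)
```

Plotting latent_l2_norms(z) and latent_l2_norms(a) against time shows when the intervention state becomes active relative to the system state.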

Languages

  • Python 80.4%
  • MATLAB 19.6%