MinD-Vis is a framework for decoding human visual stimuli from brain recordings. This document introduces the procedures required to replicate the results in Seeing Beyond the Brain: Masked Modeling Conditioned Diffusion Model for Human Vision Decoding (Submitted to CVPR2022).
Decoding visual stimuli from brain recordings aims to deepen our understanding of the human visual system and build a solid foundation for bridging human and computer vision through the Brain-Computer Interface. However, due to the scarcity of data annotations and the complexity of underlying brain information, it is challenging to decode images with faithful details and meaningful semantics. In this work, we present MinD-Vis: Sparse Masked Brain Modeling with Double-Conditioned Latent Diffusion Model for Human Vision Decoding. Specifically, by boosting the information capacity of feature representations learned from a large-scale resting-state fMRI dataset, we show that MinD-Vis can reconstruct highly plausible images with semantically matching details from brain recordings using very few paired annotations. We benchmarked our model qualitatively and quantitatively; the experimental results indicate that our method outperformed the state of the art in both semantic mapping (100-way semantic classification) and generation quality (FID), by 66% and 41%, respectively.
Our framework consists of two main stages:
- Stage A: Sparse-Coded Masked Brain Modeling (SC-MBM)
- Stage B: Double-Conditioned Latent Diffusion Model (DC-LDM)
The data/ and pretrains/ folders are not included in this repository. Please download them from FigShare and place them in the root directory of this repository as shown below.
/data
┣ 📂 HCP
┃ ┣ 📂 npz
┃ ┃ ┣ 📂 dummy_sub_01
┃ ┃ ┃ ┗ HCP_visual_voxel.npz
┃ ┃ ┣ 📂 dummy_sub_02
┃ ┃ ┃ ┗ ...
┣ 📂 Kamitani
┃ ┣ 📂 npz
┃ ┃ ┗ 📜 sbj_1.npz
┃ ┃ ┗ 📜 sbj_2.npz
┃ ┃ ┗ 📜 sbj_3.npz
┃ ┃ ┗ 📜 sbj_4.npz
┃ ┃ ┗ 📜 sbj_5.npz
┃ ┃ ┗ 📜 images_256.npz
┃ ┃ ┗ 📜 imagenet_class_index.json
┃ ┃ ┗ 📜 imagenet_training_label.csv
┃ ┃ ┗ 📜 imagenet_testing_label.csv
┣ 📂 BOLD5000
┃ ┣ 📂 BOLD5000_GLMsingle_ROI_betas
┃ ┃ ┣ 📂 py
┃ ┃ ┃ ┗ CSI1_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_allses_LHEarlyVis.npy
┃ ┃ ┃ ┗ ...
┃ ┃ ┃ ┗ CSIx_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_allses_xx.npy
┃ ┣ 📂 BOLD5000_Stimuli
┃ ┃ ┣ 📂 Image_Labels
┃ ┃ ┣ 📂 Scene_Stimuli
┃ ┃ ┣ 📂 Stimuli_Presentation_Lists
/pretrains
┣ 📂 ldm
┃ ┣ 📂 label2img (ImageNet pre-trained label-conditioned LDM)
┃ ┃ ┗ 📜 config.yaml
┃ ┃ ┗ 📜 model.ckpt
┣ 📂 GOD
┃ ┗ 📜 fmri_encoder.pth (SC-MBM pre-trained fMRI encoder)
┃ ┗ 📜 finetuned.pth (finetuned fMRI encoder + finetuned LDM)
┣ 📂 BOLD5000
┃ ┗ 📜 fmri_encoder.pth (SC-MBM pre-trained fMRI encoder)
┃ ┗ 📜 finetuned.pth (finetuned fMRI encoder + finetuned LDM)
/code
┣ 📂 sc_mbm
┃ ┗ 📜 mae_for_fmri.py
┃ ┗ 📜 trainer.py
┃ ┗ 📜 utils.py
┣ 📂 dc_ldm
┃ ┗ 📜 ldm_for_fmri.py
┃ ┗ 📜 utils.py
┃ ┣ 📂 models
┃ ┃ ┗ (adopted from LDM)
┃ ┣ 📂 modules
┃ ┃ ┗ (adopted from LDM)
┗ 📜 stageA1_mbm_pretrain.py (main script for pre-training for SC-MBM)
┗ 📜 stageA2_mbm_finetune.py (main script for tuning SC-MBM on fMRI only from test sets)
┗ 📜 stageB_ldm_finetune.py (main script for fine-tuning DC-LDM)
┗ 📜 gen_eval.py (main script for generating decoded images)
┗ 📜 dataset.py (functions for loading datasets)
┗ 📜 eval_metrics.py (functions for evaluation metrics)
┗ 📜 config.py (configurations for the main scripts)
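To check what a downloaded .npz file contains before running the scripts, a small helper can list each stored array with its shape and dtype. This is a minimal sketch using only NumPy; `summarize_npz` is a hypothetical helper, not part of this repository, and it makes no assumption about the array names stored inside the files.

```python
import numpy as np


def summarize_npz(path, allow_pickle=False):
    """Return {array_name: (shape, dtype_name)} for every array stored in an .npz archive."""
    # np.load returns a lazily loaded NpzFile for .npz archives; the `with` block closes it.
    with np.load(path, allow_pickle=allow_pickle) as archive:
        summary = {}
        for name in archive.files:
            array = archive[name]  # each access reads and decompresses the array
            summary[name] = (array.shape, str(array.dtype))
        return summary
```

For example, `summarize_npz("data/Kamitani/npz/sbj_3.npz")`, where the path assumes the layout above. If a file stores object arrays, NumPy refuses to load them with `allow_pickle=False`; pass `allow_pickle=True` only for files from a trusted source.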
Create and activate a conda environment named mind-vis from our env.yaml:
conda env create -f env.yaml
conda activate mind-vis
Due to size limits and licensing issues, the full fMRI pre-training dataset (required to replicate Stage A) must be downloaded from the official Human Connectome Project (HCP) website. The pre-processing scripts are included in this repo.
We also provide checkpoints and finetuning data on FigShare so that finetuning and decoding can be run directly. Due to the size limit, we only release the checkpoints for Subject 3 in GOD and CSI3 in BOLD5000. Checkpoints for other subjects are available upon request. After downloading, extract data/ and pretrains/ into the project directory.
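After extracting, a quick sanity check can confirm that the expected files are where the scripts will look for them. The sketch below assumes data/ and pretrains/ sit at the project root with the sub-folders shown in the directory tree; `EXPECTED_PATHS` and `missing_paths` are hypothetical names, and the list should be trimmed or extended to match what you actually downloaded.

```python
from pathlib import Path

# Assumed relative paths, taken from the directory tree in this README.
# Only the GOD (Subject 3) and BOLD5000 (CSI3) checkpoints are released on FigShare.
EXPECTED_PATHS = [
    "data/Kamitani/npz/sbj_3.npz",
    "data/Kamitani/npz/images_256.npz",
    "data/BOLD5000/BOLD5000_GLMsingle_ROI_betas/py",
    "data/BOLD5000/BOLD5000_Stimuli",
    "pretrains/ldm/label2img/config.yaml",
    "pretrains/ldm/label2img/model.ckpt",
    "pretrains/GOD/fmri_encoder.pth",
    "pretrains/BOLD5000/fmri_encoder.pth",
]


def missing_paths(root=".", expected=EXPECTED_PATHS):
    """Return the expected files/folders that do not exist under the project root."""
    root = Path(root)
    return [p for p in expected if not (root / p).exists()]
```

Running `print(missing_paths())` from the project root should print an empty list once everything is in place.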
The fMRI pre-training is performed with masked brain modeling on an fMRI dataset containing around 136,000 fMRI samples from 1205 subjects (HCP + GOD). To perform the pre-training from scratch with default parameters, run:
python code/stageA1_mbm_pretrain.py
Hyperparameters can be changed with command-line arguments:
python code/stageA1_mbm_pretrain.py --mask_ratio 0.65 --num_epoch 800 --batch_size 200
Alternatively, the parameters can be changed in code/config.py.
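The --mask_ratio argument controls what fraction of the fMRI input is hidden during masked brain modeling. The NumPy sketch below shows generic MAE-style random masking under stated assumptions: the voxel vector is split into fixed-size patches and a random subset of patches is kept visible. The voxel count, patch size, and the helpers `patchify` and `random_masking` are hypothetical; the actual implementation in code/sc_mbm/mae_for_fmri.py may differ (for example in how patches are embedded into tokens).

```python
import numpy as np


def patchify(voxels, patch_size):
    """Split a 1-D voxel vector into non-overlapping patches (trailing voxels are dropped)."""
    num_patches = len(voxels) // patch_size
    return voxels[: num_patches * patch_size].reshape(num_patches, patch_size)


def random_masking(patches, mask_ratio, rng):
    """MAE-style random masking: keep a random subset of patches, mask the rest.

    Returns the visible patches, a boolean mask (True = masked) and the kept indices.
    """
    num_patches = patches.shape[0]
    num_keep = int(num_patches * (1 - mask_ratio))
    order = rng.permutation(num_patches)
    keep_idx = np.sort(order[:num_keep])  # keep the original patch order
    mask = np.ones(num_patches, dtype=bool)
    mask[keep_idx] = False
    return patches[keep_idx], mask, keep_idx


rng = np.random.default_rng(0)
voxels = rng.standard_normal(4500)  # hypothetical voxel count for one fMRI sample
patches = patchify(voxels, patch_size=16)  # hypothetical patch size
visible, mask, keep_idx = random_masking(patches, mask_ratio=0.65, rng=rng)
print(patches.shape, visible.shape, mask.mean())
```

Only the visible patches go through the encoder in MAE-style training, so a high mask ratio such as 0.65 both makes the reconstruction task harder and reduces encoder compute per sample.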
Multi-GPU (DDP) training is supported; run with:
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS code/stageA1_mbm_pretrain.py
The pre-training results will be saved locally at results/fmri_pretrain and remotely on wandb.
After pre-training on the large-scale fMRI dataset, we need to finetune the autoencoder with fMRI data from the testing set. Run the following:
python code/stageA2_mbm_finetune.py --dataset GOD --pretrain_mbm_path results/fmri_pretrain/RUN_FOLDER_NAME/checkpoints/checkpoint.pth
Here, --dataset can be either GOD or BOLD5000, and RUN_FOLDER_NAME is the folder name generated by the pre-training run. For example:
python code/stageA2_mbm_finetune.py --dataset GOD --pretrain_mbm_path results/fmri_pretrain/01-08-2022-11:37:22/checkpoints/checkpoint.pth
The fMRI finetuning results will be saved locally at results/fmri_finetune and remotely on wandb.
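Rather than copying RUN_FOLDER_NAME by hand, the checkpoint path can be built from the most recently modified run folder. This is a sketch assuming each run folder contains checkpoints/checkpoint.pth, as in the commands above; `latest_checkpoint` is a hypothetical helper, and it uses modification time rather than parsing the timestamp-style folder names.

```python
from pathlib import Path


def latest_checkpoint(results_dir="results/fmri_pretrain"):
    """Return RUN_FOLDER_NAME/checkpoints/checkpoint.pth from the most recently modified run folder."""
    runs = [p for p in Path(results_dir).iterdir() if (p / "checkpoints" / "checkpoint.pth").is_file()]
    if not runs:
        raise FileNotFoundError(f"no run folder with checkpoints/checkpoint.pth under {results_dir}")
    newest = max(runs, key=lambda p: p.stat().st_mtime)
    return newest / "checkpoints" / "checkpoint.pth"
```

The printed path of `latest_checkpoint()` can be passed to --pretrain_mbm_path in Stage A2, and `latest_checkpoint("results/fmri_finetune")` gives the corresponding path for Stage B.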
In this stage, the cross-attention heads and the pre-trained fMRI encoder are jointly optimized with fMRI-image pairs, and decoded images are generated. This stage can be run without downloading HCP; only the finetuning datasets and the pre-trained fMRI encoder shared via our FigShare link are required. Run this stage with our provided pre-trained fMRI encoder and default parameters:
python code/stageB_ldm_finetune.py --dataset GOD
Here, --dataset can be either GOD or BOLD5000. The results and generated samples will be saved locally at results/generation and remotely on wandb.
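The cross-attention heads optimized in this stage are what let the fMRI embedding condition the diffusion model: in latent diffusion, queries come from the image latents while keys and values come from the conditioning input. The NumPy sketch below shows a single-head version of that mechanism; all sizes and the random projection matrices are hypothetical, and the actual modules under dc_ldm/modules (adopted from LDM) are multi-head PyTorch layers.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(latent_tokens, cond_tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product cross-attention.

    Queries come from the image latents; keys and values come from the fMRI embedding,
    so every latent position aggregates information from the fMRI condition.
    """
    q = latent_tokens @ w_q  # (n_latent, d)
    k = cond_tokens @ w_k  # (n_cond, d)
    v = cond_tokens @ w_v  # (n_cond, d)
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (n_latent, n_cond)
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v, weights


# Hypothetical sizes for illustration only.
rng = np.random.default_rng(0)
n_latent, d_latent = 64, 320  # e.g. flattened spatial positions of a UNet feature map
n_cond, d_cond = 16, 512  # tokens and width of the fMRI embedding
d = 40  # attention (head) dimension
latents = rng.standard_normal((n_latent, d_latent))
fmri_emb = rng.standard_normal((n_cond, d_cond))
w_q = rng.standard_normal((d_latent, d)) / np.sqrt(d_latent)
w_k = rng.standard_normal((d_cond, d)) / np.sqrt(d_cond)
w_v = rng.standard_normal((d_cond, d)) / np.sqrt(d_cond)
out, attn = cross_attention(latents, fmri_emb, w_q, w_k, w_v)
print(out.shape, attn.shape)
```

Because the gradients of the generation loss flow through both the projection weights and `cond_tokens`, training the cross-attention jointly with the fMRI encoder adapts the encoder's embedding to what the diffusion model needs.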
To run with a custom pre-trained fMRI encoder and custom parameters:
python code/stageB_ldm_finetune.py --dataset GOD --pretrain_mbm_path results/fmri_finetune/RUN_FOLDER_NAME/checkpoints/checkpoint.pth --num_epoch 500 --batch_size 5
Decoded images can also be generated directly from trained checkpoints. Only the finetuning datasets and the trained checkpoints from our FigShare link are required. Run this step with our provided checkpoints:
python code/gen_eval.py --dataset GOD
Here, --dataset can be either GOD or BOLD5000. The results and generated samples will be saved locally at results/eval and remotely on wandb.
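The semantic metric quoted in the introduction is a 100-way classification. One common way to compute an n-way top-1 accuracy is sketched below: a pre-trained image classifier scores each decoded image, and in each trial the ground-truth class must beat n - 1 randomly drawn distractor classes. This is a sketch under stated assumptions; `n_way_top1_accuracy` is a hypothetical helper, the classifier probabilities are assumed to be precomputed, and the exact protocol in code/eval_metrics.py (classifier, number of trials) may differ.

```python
import numpy as np


def n_way_top1_accuracy(probs, true_labels, n_way=100, n_trials=50, rng=None):
    """Estimate n-way top-1 classification accuracy for decoded images.

    probs: (num_images, num_classes) class probabilities that a pre-trained classifier
        assigns to each decoded image.
    true_labels: (num_images,) ground-truth class index of the matching stimulus.
    Each trial draws n_way - 1 distractor classes at random; the trial is correct if the
    ground-truth class scores strictly higher than every distractor (ties count as wrong).
    """
    rng = np.random.default_rng() if rng is None else rng
    num_images, num_classes = probs.shape
    correct = 0
    for i in range(num_images):
        y = true_labels[i]
        distractor_pool = np.delete(np.arange(num_classes), y)
        for _ in range(n_trials):
            distractors = rng.choice(distractor_pool, size=n_way - 1, replace=False)
            correct += int(probs[i, y] > probs[i, distractors].max())
    return correct / (num_images * n_trials)
```

A classifier that ranks classes at random scores about 1 / n_way on this metric, i.e. roughly 1% for the 100-way setting, which is the baseline that relative improvements should be read against.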
We thank the Kamitani Lab and the Weizmann Vision Lab for making their raw and pre-processed data public. We also thank the BOLD5000 team for making their dataset public. Our Masked Brain Modeling implementation is based on Masked Autoencoders by Facebook Research, and our Conditional Latent Diffusion Model implementation is based on the implementation from CompVis. We thank these authors for making their code and checkpoints publicly available!