Visualize what you learn: A well-explainable joint-learning framework based on multi-view mammograms and associated reports
In this paper, we introduce a novel pre-training framework for label-efficient medical image recognition that jointly learns from multi-view radiographs and their associated exam-level reports. Our proposed strategy, "visualize what you learn", provides a comprehensive and easily interpretable visualization of the visual and textual features learned by deep learning models, enabling developers to assess the depth of a model's understanding beyond its raw performance.
We evaluate our framework on a range of medical imaging datasets, covering classification, segmentation, and localization tasks in both fine-tuning and zero-shot settings. Our results demonstrate that the proposed approach achieves high performance and label efficiency compared to existing state-of-the-art methods. Overall, our approach offers a promising direction for developing more robust and effective medical image recognition systems.
Start by installing PyTorch 1.7.1 with the appropriate CUDA version, then clone this repository and install the dependencies:
$ conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch
$ git clone [email protected]:yawwG/Visualize-what-you-learn.git
$ conda env create -f environment.yml
Make sure to download the pretrained weights from here (they will be made publicly available soon!) and place them in the ./pretrained folder.
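For context, the loader below looks up registered model names in a module-level _MODELS dictionary. The snippet here is only a hypothetical illustration of how a checkpoint placed in ./pretrained might be registered; the actual names and paths are defined in the repository.

# Hypothetical registry (for illustration only): maps model names to checkpoints in ./pretrained.
_MODELS = {
    "VSWL_resnet50": "./pretrained/VSWL_resnet50.ckpt",  # assumed file name
}

def available_models():
    """List the model names that can be passed to load_VSWL by name."""
    return list(_MODELS.keys())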
import os
from typing import Union

import torch

# `_MODELS`, `available_models` and `builder` are defined elsewhere in this module.

def load_VSWL(
    name: str = "VSWL_resnet50",
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Load a VSWL model.

    Parameters
    ----------
    name : str
        A model name listed by `VSWL2.available_models()`, or the path to a
        model checkpoint containing the state_dict.
    device : Union[str, torch.device]
        The device to put the loaded model on.

    Returns
    -------
    VSWL_model : torch.nn.Module
        The VSWL model.
    """
    # Resolve the checkpoint path from a registered model name or a local file path.
    if name in _MODELS:
        ckpt_path = _MODELS[name]
    elif os.path.isfile(name):
        ckpt_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    ckpt = torch.load(ckpt_path)
    # ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))

    cfg = ckpt["hyper_parameters"]
    ckpt_dict = ckpt["state_dict"]

    # Strip the "VSWL2." prefix from the checkpoint keys so they match the model.
    fixed_ckpt_dict = {}
    for k, v in ckpt_dict.items():
        new_key = k.split("VSWL2.")[-1]
        fixed_ckpt_dict[new_key] = v
    ckpt_dict = fixed_ckpt_dict

    # Build the model from the saved hyper-parameters and load the cleaned weights.
    VSWL_model = builder.build_VSWL_model(cfg).to(device)
    VSWL_model.load_state_dict(ckpt_dict)
    return VSWL_model
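A minimal usage example, assuming a checkpoint has been downloaded to ./pretrained and registered (or passed directly as a file path):

# Load the pretrained VSWL model (GPU if available, otherwise CPU by default).
model = load_VSWL("VSWL_resnet50")
model.eval()

# Alternatively, point directly at a downloaded checkpoint file:
# model = load_VSWL("./pretrained/VSWL_resnet50.ckpt")  # assumed file name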
python zeroshot.py
# See VSWL.py for more details, including the definitions of the zero-shot application models.
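For orientation, the sketch below shows how a zero-shot classification pass could look with a loaded VSWL model. The encode_image/encode_text method names, the prompt texts, and the dummy input are assumptions made for illustration only; the repository's actual zero-shot pipeline is implemented in zeroshot.py.

import torch
import torch.nn.functional as F

model = load_VSWL("VSWL_resnet50")
model.eval()

# Dummy batch of preprocessed mammogram views (batch, channels, height, width).
images = torch.randn(4, 3, 224, 224)

# Illustrative class prompts derived from report language.
class_prompts = ["benign finding", "malignant finding"]

with torch.no_grad():
    # `encode_image` / `encode_text` are assumed method names; replace them with
    # the encoders actually exposed by the VSWL model.
    img_emb = F.normalize(model.encode_image(images), dim=-1)        # (N, D)
    txt_emb = F.normalize(model.encode_text(class_prompts), dim=-1)  # (C, D)

    # Cosine similarity between each image and each class prompt; the
    # highest-scoring prompt is the zero-shot prediction.
    logits = img_emb @ txt_emb.t()
    predictions = logits.argmax(dim=-1)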
This codebase has been developed with Python 3.7, PyTorch 1.7.1, CUDA 10.2, and pytorch-lightning 1.1.4.
Example configurations for pretraining and downstream classification can be found in the ./configs folder. All training and testing are done using the run.py script. For more documentation, please run:
python run.py --help
The preprocessing steps for each dataset can be found in preprocess_datasets.py.
Train the representation learning model with the following command:
python run.py -c pretrain_config.yaml --train
Fine-tune the VSWL pretrained image model for classification with the following commands:
python run.py -c configs/classification_config_1.yaml --train --test --train_pct 1 &
python run.py -c configs/classification_config_0.1.yaml --train --test --train_pct 0.1 &
python run.py -c configs/classification_config_0.01.yaml --train --test --train_pct 0.01
The train_pct flag randomly selects a percentage of the training set for fine-tuning the model. This is used to measure the model's performance in low-data regimes. The dataset to use is specified via the dataset key in the config file.
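Conceptually, train_pct corresponds to random subsampling of the training split. The helper below is only an illustrative sketch of that idea, not the repository's implementation:

import torch
from torch.utils.data import Subset

def subsample(dataset, train_pct, seed=42):
    """Keep a random `train_pct` fraction of `dataset` (illustrative helper)."""
    generator = torch.Generator().manual_seed(seed)
    n_keep = max(1, int(len(dataset) * train_pct))
    indices = torch.randperm(len(dataset), generator=generator)[:n_keep]
    return Subset(dataset, indices.tolist())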
Fine-tune the VSWL pretrained image model for segmentation/localization with the following commands:
python run.py -c configs/segmentation_config_1.yaml --train --test --train_pct 1 &
python run.py -c configs/segmentation_config_0.1.yaml --train --test --train_pct 0.1 &
python run.py -c configs/segmentation_config_0.01.yaml --train --test --train_pct 0.01
If you have any questions, please contact us.
Email: [email protected] (Ritse Mann); [email protected] (Tao Tan); [email protected] (Yuan Gao)
Links: Netherlands Cancer Institute, Radboud University Medical Center, Maastricht University, and The University of Hong Kong