Source code for our paper (link forthcoming) defining a pre-training benchmark system for EHR time-series data. Contact [email protected] or [email protected] with any questions. Pending interest from the community, we're eager to make this as usable as possible, and will respond promptly to any issues or questions.
Set up the repository:

```
conda env create --name comprehensive_EHR_PT -f env.yml
conda activate comprehensive_EHR_PT
```
Copies of the pre-processed dataset splits used in the paper can be obtained via Google Cloud. To access them, you must first obtain GCP access for the requisite datasets via physionet.org. See https://mimic.physionet.org/gettingstarted/cloud/ for instructions on obtaining PhysioNet GCP access.
- MIMIC-III Dataset: https://console.cloud.google.com/storage/browser/ehr_pretraining_benchmark_mimic
- eICU Dataset: https://console.cloud.google.com/storage/browser/ehr_pretraining_benchmark_eicu
Arguments for all scripts are described in the `latent_patient_trajectories/representation_learner/args.py` file. This file defines some base classes, then argument classes (with the specific args requested) for all functions, and is a good reference for determining what a given script expects. Note that this class allows you to (and we recommend) pre-set all args for a script in an (appropriately named) json file in the relevant experiment directory, then simply pass the directory to the script (via the appropriate arg) and add `--do_load_from_dir`, at which point the script will load all arguments from the json file automatically. Note that some args (specifically `regression_task_weight`, which should always be 0; `notes`, which should always be `no_notes`; and `task_weights_filepath`, which should always be `''`) are held over from older versions of the code and can largely be ignored. Similarly, the model-type-specific args corresponding to CNN, self-attention, or linear projection models are no longer used. Some sample args for different settings are given in Sample Args. Please raise a GitHub issue or contact [email protected] or [email protected] with any questions.
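As a concrete sketch of the workflow above: you can write the args json into an experiment directory and then point the script at that directory with `--do_load_from_dir`. The directory name and the `--run_dir` flag below are illustrative assumptions (consult `args.py` for the exact argument names); only `--do_load_from_dir` and the three held-out args are taken from the text above.

```shell
# Create an experiment directory holding a pre-set args file.
# The filename and keys here are illustrative -- see args.py for
# the exact arguments each script expects.
mkdir -p my_experiment
cat > my_experiment/args.json <<'EOF'
{
  "regression_task_weight": 0,
  "notes": "no_notes",
  "task_weights_filepath": ""
}
EOF

# Sanity-check that the file is valid JSON before pointing a script at it.
python3 -m json.tool my_experiment/args.json > /dev/null && echo "args.json OK"

# The script would then be launched against the directory, e.g.
# (the --run_dir flag name is a hypothetical placeholder):
# python run_model.py --run_dir my_experiment --do_load_from_dir
```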
To perform hyperparameter tuning, set up a base experimental directory and add a config file describing your hyperparameter search in that directory. This file must be named according to the `HYP_CONFIG_FILENAME` constant in `latent_patient_trajectories/constants.py`, which is (as of 7/20/20) set to `hyperparameter_search_config.json`. A sample config file is provided in `latent_patient_trajectories/representation_learner/sample_hyperopt_config.json`.
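To make the setup concrete, the snippet below writes a minimal config of that name into an experiment directory and validates it. The keys shown are purely illustrative assumptions; the actual schema is defined by `sample_hyperopt_config.json` in the repo, which you should copy instead.

```shell
# Sketch only: the key names below are assumptions, not the real schema.
# Copy sample_hyperopt_config.json from the repo for actual use.
mkdir -p hyp_experiment
cat > hyp_experiment/hyperparameter_search_config.json <<'EOF'
{
  "learning_rate": [1e-5, 1e-3],
  "batch_size": [16, 32, 64]
}
EOF

# Confirm the config parses as JSON.
python3 -m json.tool hyp_experiment/hyperparameter_search_config.json > /dev/null && echo "config OK"
```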
Then, run the script `Scripts/End to End/hyperopt_model.py` (with appropriate args, as described in the `args.py` file referenced above under the `HyperparameterSearchArgs` class) to kick off your search.
To perform a generic run, training a multi- or single-task model or a masked imputation model, use the `Args` class in `args.py` and the `run_model.py` script. As with everything else, you will need to specify a base directory, along with many other args describing the architecture you want to use and the training details.
Evaluating a pre-trained run can be accomplished with the `EvalArgs` class and the `evaluate.py` script. You will need to specify the model's training directory (e.g., the directory passed to `run_model.py`) so the script knows which model to reload.
To convert evaluation results into a form that is human-readable and aggregated across tasks, use the `get_manuscript_metrics*` functions (e.g., https://github.com/mmcdermott/comprehensive_MTL_EHR/blob/master/latent_patient_trajectories/representation_learner/evaluator.py#L41) on the output dictionaries. These functions simply re-process the more granular output of `evaluate.py` into a more readable form.
To see an example of where such a function is called, look within the hyperparameter tuning code, first here: https://github.com/mmcdermott/comprehensive_MTL_EHR/blob/master/latent_patient_trajectories/representation_learner/hyperparameter_search.py#L724 and then here: https://github.com/mmcdermott/comprehensive_MTL_EHR/blob/master/latent_patient_trajectories/representation_learner/hyperparameter_search.py#L664
That code shows where the output of the evaluator's main function is parsed into the expected input of the `get_manuscript_metrics_via_args` function.
These runs consist of pre-training a model, either via masked imputation or via multi-task pre-training (in which the model is pre-trained on all tasks except one held-out task), then fine-tuning the model on that held-out task, and finally evaluating both the pre-trained and fine-tuned models on all tasks. This could be done manually, through repeated use of the `run_model.py` script and the `--ablate` arg, but we provide a helper script that manages all requisite runs across multiple GPUs in parallel on a single machine. To use it, first create a base directory for the experiment, which will ultimately hold all runs associated with it (both pre-trained and fine-tuned). In this directory, specify the model's base args (which will be duplicated and used in all pre-training and fine-tuning experiments, with the `ablate` arg automatically adjusted to perform the appropriate experiments) as a json file parseable by the `Args` class. (Note that when run normally, models write such a file to disk in their directory, so you can simply copy the file from the model you want to examine.) You must also provide a configuration describing which GPUs are available on the system, how many models to run on each GPU at a given time, and how many GPUs each model needs (usually both of the latter are 1). A sample config is available in `latent_patient_trajectories/representation_learner/sample_task_generalizability_config.json`; note that it must be renamed to `task_generalizability_config.json` for actual use. You can additionally specify args via the `TaskGeneralizabilityArgs` class in the `args.py` file.
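The snippet below sketches what setting up that GPU configuration might look like. The filename is the one required above, but the key names are assumptions for illustration; copy `sample_task_generalizability_config.json` from the repo for the real schema.

```shell
# Sketch only: key names are assumptions -- copy the repo's
# sample_task_generalizability_config.json for actual use, then
# rename it to task_generalizability_config.json as required.
mkdir -p task_gen_experiment
cat > task_gen_experiment/task_generalizability_config.json <<'EOF'
{
  "gpus_available": [0, 1],
  "models_per_gpu": 1,
  "gpus_per_model": 1
}
EOF

# Confirm the config parses as JSON.
python3 -m json.tool task_gen_experiment/task_generalizability_config.json > /dev/null && echo "config OK"
```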
All our results can be analyzed via the `All Results.ipynb` notebook. Input files for this notebook are available upon request; as they aggregate results across both MIMIC-III and eICU, we cannot distribute them via GCP and must instead validate your PhysioNet access directly.