forked from sergeyvilov/DeepSom (commit 7531cfd by sergey.vilov, Feb 25, 2022)
# Neural Network Classifier

## Setup

Install [miniconda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html).

Then run:

```
conda env create -f environment.yml
conda activate nnc
```

## Train mode

An example run of the NNC in train mode:

```
python nn.py \
--negative_ds './datasets/dataset_name/tensors/train/train_neg_imgb.lst' \
--positive_ds './datasets/dataset_name/tensors/train/train_pos_imgb.lst' \
--output_dir './datasets/dataset_name/checkpoints' \
--tensor_width 150 \
--tensor_height 70 \
--val_fraction 0.0 \
--resample_train 'upsampling' \
--save_each 5
```

`--tensor_width` --- width of the variant tensor.

`--tensor_height` --- maximal height of the variant tensor. Variant tensors with more reads are cropped.

`--output_dir` --- folder where model and optimizer weights are saved, together with NNC prediction scores computed on the validation (test) set.

`--val_fraction` --- fraction of input variants used for validation rather than training. After each training epoch, NNC performance is evaluated on the validation set.

`--resample_train` --- use 'upsampling' to create a balanced train set by upsampling the class with fewer variants, 'downsampling' to create a balanced train set by downsampling the class with more variants, or 'None' if no resampling of the train set is needed.

`--save_each` --- how often (in epochs) model and optimizer weights are saved to disk.

See `python nn.py --help` for more training options.

Training on 70,000 positive-class and 70,000 negative-class tensors (`tensor_width`=150, `tensor_height`=70) takes about 20 h on an NVIDIA Tesla V100.
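The balancing behind `--resample_train 'upsampling'` can be sketched as follows. This is an illustrative sketch, not the NNC's actual code; the class sizes and variable names are made up:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical example: 100 positive and 40 negative tensors (indices only).
pos_idx = np.arange(100)
neg_idx = np.arange(40)

# Upsampling: draw with replacement from the smaller class
# until both classes contribute the same number of examples.
n = max(len(pos_idx), len(neg_idx))
neg_upsampled = rng.choice(neg_idx, size=n, replace=True)

balanced = np.concatenate([pos_idx, neg_upsampled])
rng.shuffle(balanced)

print(len(balanced))  # 200: 100 examples per class
```

Downsampling is the mirror image: sample the larger class without replacement down to the size of the smaller one.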
## Evaluation (test) mode

An example run of the NNC in evaluation (test) mode:

```
python nn.py \
--negative_ds './datasets/dataset_name/tensors/test/test_neg_imgb.lst' \
--positive_ds './datasets/dataset_name/tensors/test/test_pos_imgb.lst' \
--output_dir './datasets/dataset_name/test' \
--load_weights 1 \
--config_start_base './datasets/dataset_name/checkpoints/epoch_20_weights' \
--tensor_width 150 \
--tensor_height 70 \
--val_fraction 1.0
```

This looks like the train-mode invocation, but all tensors are used for evaluation (`--val_fraction 1.0`). The parameters `--load_weights` and `--config_start_base` load the weights of a pretrained model before evaluation.

## Inference mode

An example run of the NNC in inference mode:

```
python nn.py \
--inference_ds './projects/project_name/to_classify/tensors/inference_imgb.lst' \
--inference_mode 1 \
--output_dir './inference/projects/project_name/to_classify/dataset_name' \
--load_weights 1 \
--config_start_base './datasets/dataset_name/checkpoints/epoch_20_weights' \
--tensor_width 150 \
--tensor_height 70
```

NNC output scores for variants in `--inference_ds` are written to an `inference.csv` file in `--output_dir`. Each line of `inference.csv` corresponds to a single SNP variant. Use `variants.csv.gz` from the dataset tensors directory to identify the variants in `--inference_ds`.
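Matching scores back to variants can be sketched with pandas, assuming `inference.csv` and `variants.csv.gz` list variants in the same row order. The column names below are assumptions for illustration, not the NNC's actual headers; in practice the two frames come from `pd.read_csv('inference.csv')` and `pd.read_csv('variants.csv.gz')`:

```python
import pandas as pd

# Stand-ins for the real files (hypothetical columns and values).
scores = pd.DataFrame({'score': [0.91, 0.08, 0.55]})
variants = pd.DataFrame({'chrom': ['chr1', 'chr2', 'chr7'],
                         'pos': [12345, 67890, 24680]})

# Row-by-row alignment: the i-th score belongs to the i-th variant.
annotated = pd.concat([variants.reset_index(drop=True),
                       scores.reset_index(drop=True)], axis=1)
print(annotated)
```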
# Using NNC output to compute probabilities and classify variants

For any given variant, the NNC outputs a continuous score ![equation](https://latex.codecogs.com/svg.image?s) between 0 and 1. The higher ![equation](https://latex.codecogs.com/svg.image?s), the higher the probability ![equation](https://latex.codecogs.com/svg.image?p_%7Bsom%7D) that the variant is somatic. The relation between ![equation](https://latex.codecogs.com/svg.image?s) and ![equation](https://latex.codecogs.com/svg.image?p_%7Bsom%7D) is non-linear.

To compute ![equation](https://latex.codecogs.com/svg.image?p_%7Bsom%7D) based on ![equation](https://latex.codecogs.com/svg.image?s), one needs to calibrate the NNC output. For calibration, one runs the pre-trained NNC on test SNPs. Then the NNC output values are binned s.t. there are at least 20 values per bin. The probability that a variant whose score ends up in bin ![equation](https://latex.codecogs.com/svg.image?s_%7Bi%7D) is somatic is given by the Bayes formula:

![equation](https://latex.codecogs.com/svg.image?p_%7Bsom%7D(s%5Csubset%20s_i)=%5Cfrac%7BP(s%5Csubset%20s_i%7Csom)%5Ctimes%20N_%7Bsom%7D%7D%7BP(s%5Csubset%20s_i%7Csom)%5Ctimes%20N_%7Bsom%7D%20+%20P(s%5Csubset%20s_i%7Cneg)%5Ctimes%20N_%7Bneg%7D%7D)

where ![equation](https://latex.codecogs.com/svg.image?N_%7Bsom%7D) is the number of somatic variants per WGS sample, ![equation](https://latex.codecogs.com/svg.image?N_%7Bneg%7D) is the number of germline variants and artefacts per WGS sample, ![equation](https://latex.codecogs.com/svg.image?P(s%5Csubset%20s_i%7Csom)%20) is the fraction of true somatic variants at the input that end up in bin ![equation](https://latex.codecogs.com/svg.image?s_%7Bi%7D), and ![equation](https://latex.codecogs.com/svg.image?P(s%5Csubset%20s_i%7Cneg)%20) is the fraction of true germline variants and artefacts at the input that end up in bin ![equation](https://latex.codecogs.com/svg.image?s_%7Bi%7D).

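The calibration step can be sketched numerically. This is an illustrative sketch only: the score distributions are simulated, the per-sample counts `N_som` and `N_neg` are made-up placeholders, and the fixed 10-bin grid stands in for the adaptive binning (at least 20 values per bin) described above:

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulated NNC scores on a labelled test set: true somatic variants
# tend to score high, germline variants/artefacts tend to score low.
s_som = rng.beta(5, 2, size=1000)   # scores of true somatic variants
s_neg = rng.beta(2, 5, size=1000)   # scores of germline variants/artefacts

# Assumed per-WGS-sample class sizes (placeholders, for illustration only).
N_som, N_neg = 3_000, 3_000_000

bins = np.linspace(0, 1, 11)

# P(s in s_i | som) and P(s in s_i | neg): fraction of each class per bin.
P_som, _ = np.histogram(s_som, bins=bins)
P_neg, _ = np.histogram(s_neg, bins=bins)
P_som = P_som / len(s_som)
P_neg = P_neg / len(s_neg)

# Bayes formula from the text, evaluated per bin.
p_som = (P_som * N_som) / (P_som * N_som + P_neg * N_neg)

print(np.round(p_som, 4))  # calibrated probability per score bin
```

Because negatives vastly outnumber somatic variants per sample, even a high raw score can map to a modest calibrated probability, which is why this step matters.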
Variant classification is performed by imposing a threshold on ![equation](https://latex.codecogs.com/svg.image?s) s.t. all variants with ![equation](https://latex.codecogs.com/svg.image?s%3Es_%7Bthr%7D) are considered somatic. This threshold can be chosen based on the corresponding probability ![equation](https://latex.codecogs.com/svg.image?p_%7Bsom%7D). Alternatively, ![equation](https://latex.codecogs.com/svg.image?s_%7Bthr%7D) can be chosen based on the ROC curve, a plot of the true positive rate (TPR) against the false positive rate (FPR) w.r.t. somatic variants. As ![equation](https://latex.codecogs.com/svg.image?s_%7Bthr%7D) increases, the operating point moves along the ROC curve from the upper right to the lower left corner. In any case, variants used for choosing ![equation](https://latex.codecogs.com/svg.image?s_%7Bthr%7D) must be different from those used for NNC training.
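Choosing a threshold from the ROC curve can be sketched with scikit-learn, which the provided environment already includes. The labels and scores below are toy stand-ins, and maximizing TPR minus FPR (Youden's J) is just one common selection rule, not one prescribed by the NNC:

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy labelled scores: 1 = somatic, 0 = germline variant or artefact.
y_true = np.array([0, 0, 0, 0, 1, 0, 1, 1, 0, 1])
s = np.array([0.1, 0.2, 0.3, 0.35, 0.4, 0.45, 0.7, 0.8, 0.25, 0.9])

fpr, tpr, thresholds = roc_curve(y_true, s)

# One common choice: the threshold maximizing TPR - FPR (Youden's J).
j = np.argmax(tpr - fpr)
s_thr = thresholds[j]
print(s_thr, tpr[j], fpr[j])
```

All variants scoring above the chosen `s_thr` would then be classified as somatic.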
## environment.yml
name: nnc
channels:
- pytorch
- anaconda
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- blas=1.0=mkl
- bottleneck=1.3.2=py39hdd57654_1
- ca-certificates=2020.10.14=0
- certifi=2021.10.8=py39h06a4308_2
- cudatoolkit=11.3.1=h2bc3f7f_2
- intel-openmp=2021.4.0=h06a4308_3561
- joblib=0.17.0=py_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgfortran-ng=7.5.0=ha8ba4b0_17
- libgfortran4=7.5.0=ha8ba4b0_17
- libgomp=9.3.0=h5101ec6_17
- libstdcxx-ng=9.3.0=hd4cf53a_17
- libuv=1.40.0=h7b6447c_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py39h7f8727e_0
- mkl_fft=1.3.1=py39hd3c417c_0
- mkl_random=1.2.2=py39h51133e4_0
- ncurses=6.3=h7f8727e_2
- numexpr=2.8.1=py39h6abb31d_0
- numpy=1.21.2=py39h20f2e39_0
- numpy-base=1.21.2=py39h79a1101_0
- openssl=1.1.1m=h7f8727e_0
- packaging=21.3=pyhd3eb1b0_0
- pandas=1.4.1=py39h295c915_0
- pip=21.2.4=py39h06a4308_0
- pyparsing=3.0.4=pyhd3eb1b0_0
- python=3.9.7=h12debd9_1
- python-dateutil=2.8.2=pyhd3eb1b0_0
- pytorch=1.10.2=py3.9_cuda11.3_cudnn8.2.0_0
- pytorch-mutex=1.0=cuda
- pytz=2021.3=pyhd3eb1b0_0
- readline=8.1.2=h7f8727e_1
- scikit-learn=1.0.2=py39h51133e4_1
- scipy=1.7.3=py39hc147768_0
- setuptools=58.0.4=py39h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.37.2=hc218d9a_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- tk=8.6.11=h1ccaba5_0
- typing_extensions=3.10.0.2=pyh06a4308_0
- tzdata=2021e=hda174b7_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7f8727e_4
prefix: /home/icb/sergey.vilov/miniconda3/envs/nnc