Skip to content

Commit

Permalink
adding scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhay Gupta committed Oct 7, 2020
1 parent 336e9d2 commit 89b2f08
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,44 @@ To Do:
- [ ] Correct hyper parameters
- [ ] Full Axial-ViT

## Installation

```bash
pip install tensorboardX
mkdir data
cd data
ln -s path/to/dataset imagenet
```

## Running the Scripts

For non-distributed training:

```bash
python train.py --model ViT --name vit_logs
```

For distributed training:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python dist_train.py --model ViT --name vit_dist_logs
```

For testing add the `--test` parameter:

```bash
python train.py --model ViT --name vit_logs --test
```

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python dist_train.py --model ViT --name vit_dist_logs --test
```

## References

1. [BiTResNet](https://github.com/google-research/big_transfer/tree/master/bit_pytorch)
2. [AxialResNet](https://github.com/csrhddlam/axial-deeplab)
3. [Training Scripts](https://github.com/csrhddlam/axial-deeplab)

## Citations

Expand Down
2 changes: 1 addition & 1 deletion dist_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def parse_args():
default=False,
help="To apply color augmentation or not.",
)
parser.add_argument('--model', default='axial50s', help='Model names.')
parser.add_argument('--model', default='ViT', help='Model names.')
parser.add_argument(
'--epochs', type=int, default=130, help='number of epochs to train'
)
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def parse_args():
default=False,
help="To apply color augmentation or not.",
)
parser.add_argument('--model', default='axial50s', help='Model names.')
parser.add_argument('--model', default='ViT', help='Model names.')
parser.add_argument(
'--epochs', type=int, default=130, help='number of epochs to train'
)
Expand Down

0 comments on commit 89b2f08

Please sign in to comment.