Train RetinaNet with Focal Loss in PyTorch.
Reference:
[1] Focal Loss for Dense Object Detection
-
datagen.py: defines ListDataset, based on torch.utils.data.Dataset. It uses these other files:
- transform.py: data transforms (e.g. random_flip, random_crop, resize)
- utils.py: box_iou, box_nms, and some display utilities
- encoder.py: encodes the data to 9 anchor boxes
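
The 9 anchor boxes can be illustrated with a small sketch. It assumes the configuration from the paper [1]: 3 aspect ratios (1:2, 1:1, 2:1) times 3 scales (2^0, 2^(1/3), 2^(2/3)) at every feature-map location. The function name `anchor_sizes` and the convention that a ratio means height / width are assumptions made for this example; encoder.py may organise this differently.

```python
import math


def anchor_sizes(base_area,
                 aspect_ratios=(0.5, 1.0, 2.0),
                 scales=(1.0, 2 ** (1 / 3), 2 ** (2 / 3))):
    """Return one (width, height) pair per (aspect ratio, scale) combination.

    Hypothetical helper, not the repo's encoder. `ratio` is taken to mean
    height / width, which is an assumption of this sketch.
    """
    sizes = []
    for ratio in aspect_ratios:
        # Pick w and h so that w * h == base_area and h / w == ratio.
        w = math.sqrt(base_area / ratio)
        h = w * ratio
        for s in scales:
            sizes.append((w * s, h * s))
    return sizes


# The paper uses a base area of 32 * 32 on the finest pyramid level.
anchors = anchor_sizes(32 * 32)
print(len(anchors))  # 9
```

Each pyramid level would reuse the same 9 shapes with a larger base area (the paper goes from 32^2 up to 512^2).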
-
train.py:
- fpn.py: Feature Pyramid Network (FPN) definition & architecture. TODO: review the code against the paper
- loss.py: Focal loss definition
- retinanet.py: RetinaNet definition, built on FPN50 from fpn.py
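
As a reminder of what loss.py is about, here is a minimal scalar sketch of the focal loss from [1], FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). It is written in plain Python for a single binary prediction rather than with PyTorch tensors, so it is not the implementation in loss.py. The name `focal_loss` and its signature are assumptions for this example; alpha = 0.25 and gamma = 2 are the defaults reported in the paper.

```python
import math


def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Binary focal loss for one prediction (illustrative sketch only).

    p: predicted probability of the positive class, strictly between 0 and 1.
    y: ground-truth label, 1 (object) or 0 (background).
    """
    # p_t is the probability the model assigns to the true class.
    p_t = p if y == 1 else 1.0 - p
    # alpha_t balances positives against negatives.
    alpha_t = alpha if y == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma down-weights examples that are already easy.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


easy = focal_loss(0.9, 1)   # confident and correct: small loss
hard = focal_loss(0.1, 1)   # confident and wrong: much larger loss
print(easy < hard)  # True
```

With gamma = 0 the modulating factor disappears and the expression reduces to alpha-weighted cross-entropy, which is a handy sanity check when reviewing the real code.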
-
test.py: runs the trained model & decodes the results.
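
Decoding detections usually ends with non-maximum suppression, which is presumably what box_nms in utils.py is for. The sketch below shows the standard greedy version in plain Python. The functions `iou` and `greedy_nms`, the (x1, y1, x2, y2) box format, and the 0.5 threshold are assumptions for illustration and stand in for the repo's own box_iou and box_nms.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, threshold=0.5):
    """Return indices of the boxes kept, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        # Drop every remaining box that overlaps the kept one too much.
        order = [i for i in order if iou(boxes[best], boxes[i]) <= threshold]
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
kept = greedy_nms(boxes, scores)
print(kept)  # [0, 2]
```

The second box overlaps the first with an IoU of 81 / 119, about 0.68, so it is suppressed; the third box does not overlap and is kept.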
-
script/get_state_dict.py: uses the pre-trained weights to generate the FPN50 feature_extractor weights
The code is good. TODO: refer to the paper, review the code until every snippet is understood, and train on your own data.
20180808: trained with Metro data and it seems OK. Train with 1e-3.