Skip to content

Commit

Permalink
also split sem_seg_head into layers & loss
Browse files Browse the repository at this point in the history
Summary: The similar split we did for mask/keypoint makes sense for sem_seg_head as well.

Reviewed By: rbgirshick, alexander-kirillov

Differential Revision: D20365822

fbshipit-source-id: c4eaf36a4d685e924ede50d473dba910facc8fd2
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Mar 10, 2020
1 parent a4450de commit 408c1cf
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 38 deletions.
8 changes: 5 additions & 3 deletions .github/ISSUE_TEMPLATE/unexpected-problems-bugs.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ post according to this template:
<put diff or code here>
```
2. what exact command you run:
3. what you observed (including the full logs):
3. what you observed (including __full logs__):
```
<put logs here>
```
Expand All @@ -35,8 +35,10 @@ Only in one of the two conditions we will help with it:

## Environment:

Run `python -m detectron2.utils.collect_env` in the environment where you observerd the issue, and paste the output.
If detectron2 hasn't been successfully installed, use `python detectron2/utils/collect_env.py` (after getting this file from github).
Provide your environment information using the following command:
```
wget -nc -q https://github.com/facebookresearch/detectron2/raw/master/detectron2/utils/collect_env.py && python collect_env.py
```

If your issue looks like an installation issue / environment issue,
please first try to solve it yourself with the instructions in
Expand Down
7 changes: 5 additions & 2 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ also installs detectron2 with a few simple commands.
- [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
You can install them together at [pytorch.org](https://pytorch.org) to make sure of this.
- OpenCV, optional, needed by demo and visualization
- pycocotools: `pip install cython; pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'`
- pycocotools: `pip install cython; pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'`


### Build Detectron2 from Source

After having the above dependencies and gcc & g++ ≥ 4.9, run:
After having the above dependencies and gcc & g++ ≥ 5, run:
```
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
# (add --user if you don't have permission)
Expand Down Expand Up @@ -154,6 +154,9 @@ to match your local CUDA installation, or install a different version of CUDA to
</summary>
<br/>
Please build and install detectron2 following the instructions above.

If you are running code from detectron2's root directory, `cd` to a different one.
Otherwise you may not import the code that you installed.
</details>

<details>
Expand Down
6 changes: 3 additions & 3 deletions MODEL_ZOO.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ In addition to these official baseline models, you can find more models in [proj

We provide backbone models pretrained on ImageNet-1k dataset.
These models have __different__ format from those provided in Detectron: we do not fuse BatchNorm into an affine layer.
* [R-50.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/MSRA/R-50.pkl): converted copy of MSRA's original ResNet-50 model
* [R-101.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/MSRA/R-101.pkl): converted copy of MSRA's original ResNet-101 model
* [X-101-32x8d.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/FAIR/X-101-32x8d.pkl): ResNeXt-101-32x8d model trained with Caffe2 at FB
* [R-50.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/MSRA/R-50.pkl): converted copy of [MSRA's original ResNet-50](https://github.com/KaimingHe/deep-residual-networks) model.
* [R-101.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/MSRA/R-101.pkl): converted copy of [MSRA's original ResNet-101](https://github.com/KaimingHe/deep-residual-networks) model.
* [X-101-32x8d.pkl](https://dl.fbaipublicfiles.com/detectron2/ImageNetPretrained/FAIR/X-101-32x8d.pkl): ResNeXt-101-32x8d model trained with Caffe2 at FB.

Pretrained models in Detectron's format can still be used. For example:
* [X-152-32x8d-IN5k.pkl](https://dl.fbaipublicfiles.com/detectron/ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl):
Expand Down
37 changes: 26 additions & 11 deletions detectron2/modeling/meta_arch/semantic_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,35 @@ def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
weight_init.c2_msra_fill(self.predictor)

def forward(self, features, targets=None):
"""
Returns:
In training, returns (None, dict of losses)
In inference, returns (predictions, {})
"""
x = self.layers(features)
if self.training:
return None, self.losses(x, targets)
else:
x = F.interpolate(
x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
)
return x, {}

def layers(self, features):
for i, f in enumerate(self.in_features):
if i == 0:
x = self.scale_heads[i](features[f])
else:
x = x + self.scale_heads[i](features[f])
x = self.predictor(x)
x = F.interpolate(x, scale_factor=self.common_stride, mode="bilinear", align_corners=False)

if self.training:
losses = {}
losses["loss_sem_seg"] = (
F.cross_entropy(x, targets, reduction="mean", ignore_index=self.ignore_value)
* self.loss_weight
)
return [], losses
else:
return x, {}
return x

def losses(self, predictions, targets):
predictions = F.interpolate(
predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
)
loss = F.cross_entropy(
predictions, targets, reduction="mean", ignore_index=self.ignore_value
)
losses = {"loss_sem_seg": loss * self.loss_weight}
return losses
2 changes: 1 addition & 1 deletion detectron2/structures/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class Boxes:
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
Attributes:
tensor (torch.Tensor): float matrix of Nx4.
tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2).
"""

BoxSizeType = Union[List[int], Tuple[int, int]]
Expand Down
36 changes: 24 additions & 12 deletions detectron2/utils/collect_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,25 @@ def collect_env_info():
data.append(("numpy", np.__version__))

try:
import detectron2
from detectron2 import _C
except ImportError:
data.append(("detectron2._C", "failed to import"))
else:
import detectron2 # noqa

data.append(
("detectron2", detectron2.__version__ + " @" + os.path.dirname(detectron2.__file__))
)
data.append(("detectron2 compiler", _C.get_compiler_version()))
data.append(("detectron2 CUDA compiler", _C.get_cuda_version()))
if has_cuda:
data.append(
("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, _C.__file__))
)
except ImportError:
data.append(("detectron2", "failed to import"))
else:
try:
from detectron2 import _C
except ImportError:
data.append(("detectron2._C", "failed to import"))
else:
data.append(("detectron2 compiler", _C.get_compiler_version()))
data.append(("detectron2 CUDA compiler", _C.get_cuda_version()))
if has_cuda:
data.append(
("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, _C.__file__))
)

data.append(get_env_module())
data.append(("PyTorch", torch.__version__ + " @" + os.path.dirname(torch.__file__)))
Expand Down Expand Up @@ -138,4 +143,11 @@ def collect_env_info():


if __name__ == "__main__":
print(collect_env_info())
try:
import detectron2 # noqa
except ImportError:
print(collect_env_info())
else:
from detectron2.utils.collect_env import collect_env_info

print(collect_env_info())
18 changes: 12 additions & 6 deletions docs/tutorials/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@ You can use
which provides minimal abstraction for single-cost single-optimizer single-data-source training.
The builtin `train_net.py` script uses
[DefaultTrainer().train()](../modules/engine.html#detectron2.engine.defaults.DefaultTrainer),
which includes more standard default behavior that one might want to opt in.
which includes more standard default behavior that one might want to opt in,
including default configurations for logging, evaluation, checkpointing etc.
This also means that it's less likely to support some non-standard behavior
you might want during research.

To customize the training loops, you can either start
from [tools/plain_train_net.py](../../tools/plain_train_net.py),
or look at the source code of [DefaultTrainer](../../detectron2/engine/defaults.py)
and overwrite some of its behaviors with new parameters or new hooks.
To customize the training loops, you can:

1. If your customization is similar to what `DefaultTrainer` is already doing,
you can look at the source code of [DefaultTrainer](../../detectron2/engine/defaults.py)
and overwrite some of its behaviors with new parameters or new hooks.
2. If you need something very novel, you can start from [tools/plain_train_net.py](../../tools/plain_train_net.py) to implement them yourself.

### Logging of Metrics

During training, metrics are logged with a centralized [EventStorage](../modules/utils.html#detectron2.utils.events.EventStorage).
During training, metrics are saved to a centralized [EventStorage](../modules/utils.html#detectron2.utils.events.EventStorage).
You can use the following code to access it and log metrics to it:
```
from detectron2.utils.events import get_event_storage
Expand All @@ -41,3 +43,7 @@ if self.training:
```

Refer to its documentation for more details.

Metrics are then saved to various destinations with [EventWriter](../modules/utils.html#module-detectron2.utils.events).
DefaultTrainer enables a few `EventWriter` with default configurations.
See above for how to customize them.
1 change: 1 addition & 0 deletions tools/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
A script to benchmark builtin models.
Expand Down
1 change: 1 addition & 0 deletions tools/caffe2_converter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import argparse
import os
Expand Down
1 change: 1 addition & 0 deletions tools/plain_train_net.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Detectron2 training script with a plain training loop.
Expand Down
1 change: 1 addition & 0 deletions tools/train_net.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Detection Training Script.
Expand Down
1 change: 1 addition & 0 deletions tools/visualize_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import numpy as np
Expand Down

0 comments on commit 408c1cf

Please sign in to comment.