Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
add cutmix augmentation (#150)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #150

Reviewed By: mannatsingh, Tete-Xiao

Differential Revision: D29488389

Pulled By: pdollar

fbshipit-source-id: d180abbd7e622ca171c45e1be6d3cef8963b7178
  • Loading branch information
Tete Xiao authored and facebook-github-bot committed Jun 30, 2021
1 parent 5b57451 commit 39f2dc2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
3 changes: 3 additions & 0 deletions pycls/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@
# Batch mixup regularization value in 0 to 1 (0 gives no mixup)
_C.TRAIN.MIXUP_ALPHA = 0.0

# Batch cutmix regularization value in 0 to 1 (0 gives no cutmix)
# NOTE: if both MIXUP_ALPHA and CUTMIX_ALPHA are > 0, net.mixup picks one of the
# two augmentations per batch with 50-50 probability
_C.TRAIN.CUTMIX_ALPHA = 0.0

# Standard deviation for AlexNet-style PCA jitter (0 gives no PCA jitter)
_C.TRAIN.PCA_STD = 0.1

Expand Down
25 changes: 21 additions & 4 deletions pycls/core/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,31 @@ def forward(self, x, y):


def mixup(inputs, labels):
    """
    Apply mixup or cutmix to a minibatch depending on MIXUP_ALPHA and CUTMIX_ALPHA.

    If MIXUP_ALPHA > 0, applies mixup (https://arxiv.org/abs/1710.09412).
    If CUTMIX_ALPHA > 0, applies cutmix (https://arxiv.org/abs/1905.04899).
    If both MIXUP_ALPHA > 0 and CUTMIX_ALPHA > 0, 50-50 chance of which is applied.

    Args:
        inputs: image batch, indexed as (N, C, H, W).
        labels: one-hot labels of shape (N, cfg.MODEL.NUM_CLASSES).

    Returns:
        Tuple of (augmented inputs, soft mixed labels, hard labels via argmax).
    """
    assert labels.shape[1] == cfg.MODEL.NUM_CLASSES, "mixup labels must be one-hot"
    mixup_alpha, cutmix_alpha = cfg.TRAIN.MIXUP_ALPHA, cfg.TRAIN.CUTMIX_ALPHA
    # If both augmentations are enabled, disable mixup half the time so that
    # the elif branch below applies cutmix instead (50-50 split).
    mixup_alpha = mixup_alpha if (cutmix_alpha == 0 or np.random.rand() < 0.5) else 0
    if mixup_alpha > 0:
        m = np.random.beta(mixup_alpha, mixup_alpha)
        permutation = torch.randperm(labels.shape[0])
        inputs = m * inputs + (1.0 - m) * inputs[permutation, :]
        labels = m * labels + (1.0 - m) * labels[permutation, :]
    elif cutmix_alpha > 0:
        m = np.random.beta(cutmix_alpha, cutmix_alpha)
        permutation = torch.randperm(labels.shape[0])
        h, w = inputs.shape[2], inputs.shape[3]
        # Box side lengths scale with sqrt(1 - m) so the box area fraction ~ (1 - m).
        # Use builtin int(): np.int was deprecated in NumPy 1.20 and removed in 1.24.
        w_b, h_b = int(w * np.sqrt(1.0 - m)), int(h * np.sqrt(1.0 - m))
        # Uniformly sample the box center, then clip the box to the image bounds.
        x_c, y_c = np.random.randint(w), np.random.randint(h)
        x_0, y_0 = np.clip(x_c - w_b // 2, 0, w), np.clip(y_c - h_b // 2, 0, h)
        x_1, y_1 = np.clip(x_c + w_b // 2, 0, w), np.clip(y_c + h_b // 2, 0, h)
        # Recompute m from the actual (clipped) box area so label weights match
        # the true pixel proportion kept from each image.
        m = 1.0 - ((x_1 - x_0) * (y_1 - y_0) / (h * w))
        # Paste the permuted batch's box region into the batch in place.
        inputs[:, :, y_0:y_1, x_0:x_1] = inputs[permutation, :, y_0:y_1, x_0:x_1]
        labels = m * labels + (1.0 - m) * labels[permutation, :]
    return inputs, labels, labels.argmax(1)


Expand Down

0 comments on commit 39f2dc2

Please sign in to comment.