diff --git a/pycls/core/config.py b/pycls/core/config.py
index 1f7dc80..1245fd6 100644
--- a/pycls/core/config.py
+++ b/pycls/core/config.py
@@ -265,6 +265,9 @@
 # Batch mixup regularization value in 0 to 1 (0 gives no mixup)
 _C.TRAIN.MIXUP_ALPHA = 0.0
 
+# Batch cutmix regularization value in 0 to 1 (0 gives no cutmix)
+_C.TRAIN.CUTMIX_ALPHA = 0.0
+
 # Standard deviation for AlexNet-style PCA jitter (0 gives no PCA jitter)
 _C.TRAIN.PCA_STD = 0.1
 
diff --git a/pycls/core/net.py b/pycls/core/net.py
index 3c1e44f..8d62013 100644
--- a/pycls/core/net.py
+++ b/pycls/core/net.py
@@ -89,12 +89,32 @@ def forward(self, x, y):
 
 
 def mixup(inputs, labels):
-    """Apply mixup to minibatch (https://arxiv.org/abs/1710.09412)."""
-    alpha = cfg.TRAIN.MIXUP_ALPHA
+    """
+    Apply mixup or cutmix to minibatch depending on MIXUP_ALPHA and CUTMIX_ALPHA.
+    If MIXUP_ALPHA > 0, applies mixup (https://arxiv.org/abs/1710.09412).
+    If CUTMIX_ALPHA > 0, applies cutmix (https://arxiv.org/abs/1905.04899).
+    If both MIXUP_ALPHA > 0 and CUTMIX_ALPHA > 0, 50-50 chance of which is applied.
+    """
     assert labels.shape[1] == cfg.MODEL.NUM_CLASSES, "mixup labels must be one-hot"
-    if alpha > 0:
-        m = np.random.beta(alpha, alpha)
+    mixup_alpha, cutmix_alpha = cfg.TRAIN.MIXUP_ALPHA, cfg.TRAIN.CUTMIX_ALPHA
+    # If both are enabled, flip a fair coin to pick which augmentation to apply
+    mixup_alpha = mixup_alpha if (cutmix_alpha == 0 or np.random.rand() < 0.5) else 0
+    if mixup_alpha > 0:
+        m = np.random.beta(mixup_alpha, mixup_alpha)
         permutation = torch.randperm(labels.shape[0])
         inputs = m * inputs + (1.0 - m) * inputs[permutation, :]
         labels = m * labels + (1.0 - m) * labels[permutation, :]
+    elif cutmix_alpha > 0:
+        # CutMix: paste a random box from a permuted batch into each image
+        m = np.random.beta(cutmix_alpha, cutmix_alpha)
+        permutation = torch.randperm(labels.shape[0])
+        h, w = inputs.shape[2], inputs.shape[3]
+        w_b, h_b = int(w * np.sqrt(1.0 - m)), int(h * np.sqrt(1.0 - m))
+        x_c, y_c = np.random.randint(w), np.random.randint(h)
+        x_0, y_0 = np.clip(x_c - w_b // 2, 0, w), np.clip(y_c - h_b // 2, 0, h)
+        x_1, y_1 = np.clip(x_c + w_b // 2, 0, w), np.clip(y_c + h_b // 2, 0, h)
+        # Recompute m as the exact fraction of pixels kept (box may be clipped)
+        m = 1.0 - ((x_1 - x_0) * (y_1 - y_0) / (h * w))
+        inputs[:, :, y_0:y_1, x_0:x_1] = inputs[permutation, :, y_0:y_1, x_0:x_1]
+        labels = m * labels + (1.0 - m) * labels[permutation, :]
     return inputs, labels, labels.argmax(1)