From 29dbf5efdb8ca2eba0dc7017bf90bed8a5e806d3 Mon Sep 17 00:00:00 2001 From: davda54 Date: Thu, 12 Nov 2020 18:02:36 +0100 Subject: [PATCH] cutout bugfix --- example/utility/cutout.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/example/utility/cutout.py b/example/utility/cutout.py index 98d5ca7..e7af491 100644 --- a/example/utility/cutout.py +++ b/example/utility/cutout.py @@ -8,12 +8,13 @@ def __init__(self, size=16, p=0.5): self.p = p def __call__(self, image): - if torch.rand([1]).item() > self.p: return image + if torch.rand([1]).item() > self.p: + return image - left = torch.randint(-self.half_size, image.shape[0] - self.half_size, [1]).item() - top = torch.randint(-self.half_size, image.shape[1] - self.half_size, [1]).item() - right = min(image.shape[0], left + self.size) - bottom = min(image.shape[1], top + self.size) + left = torch.randint(-self.half_size, image.size(1) - self.half_size, [1]).item() + top = torch.randint(-self.half_size, image.size(2) - self.half_size, [1]).item() + right = min(image.size(1), left + self.size) + bottom = min(image.size(2), top + self.size) - image[max(0,left):right, max(0,top):bottom, :] = 0 + image[:, max(0, left): right, max(0, top): bottom] = 0 return image