Skip to content

Commit

Permalink
Add support in CropMirrorNormalize for uneven sizes of mean and std (…
Browse files Browse the repository at this point in the history
…#1708)

* Add support in CropMirrorNormalize for uneven sizes of mean and std

Signed-off-by: Joaquin Anton <[email protected]>
  • Loading branch information
jantonguirao authored Feb 4, 2020
1 parent ab36ff5 commit 7975e9c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
19 changes: 19 additions & 0 deletions dali/operators/image/crop/crop_mirror_normalize.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,29 @@ class CropMirrorNormalize : public Operator<Backend>, protected CropAttr {
inv_std_vec_ = { spec.GetArgument<float>("std") };
}

DALI_ENFORCE(!mean_vec_.empty() && !inv_std_vec_.empty(),
"mean and standard deviation can't be empty");

DALI_ENFORCE(
mean_vec_.size() == inv_std_vec_.size() || mean_vec_.size() == 1 || inv_std_vec_.size() == 1,
"`mean` and `stddev` must either be of the same size, be scalars, or one of them can be a "
"vector and the other a scalar.");

// Inverse the std-deviation
for (auto &element : inv_std_vec_) {
element = 1.f / element;
}

// Handle irregular mean/std argument lengths
auto args_size = std::max(mean_vec_.size(), inv_std_vec_.size());
if (mean_vec_.size() != inv_std_vec_.size()) {
if (mean_vec_.size() == 1)
mean_vec_.resize(args_size, mean_vec_[0]);

if (inv_std_vec_.size() == 1)
inv_std_vec_.resize(args_size, inv_std_vec_[0]);
}

if (std::is_same<Backend, GPUBackend>::value) {
kmgr_.Resize(1, 1);
} else {
Expand Down
13 changes: 12 additions & 1 deletion dali/test/python/test_operator_crop_mirror_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,16 @@ def crop_mirror_normalize_func(crop_z, crop_y, crop_x,
out = image[:, start_z:end_z, start_y:end_y, start_x:end_x, :]
D, H, W = out.shape[1], out.shape[2], out.shape[3]

if not mean:
mean = [0.0]
if not std:
std = [1.0]

if len(mean) == 1:
mean = C * mean
if len(std) == 1:
std = C * std

assert len(mean) == C and len(std) == C
inv_std = [np.float32(1.0) / np.float32(std[c]) for c in range(C)]
mean = np.float32(mean)
Expand Down Expand Up @@ -369,7 +376,11 @@ def check_cmn_random_data_vs_numpy(device, batch_size, output_dtype, input_layou

def test_cmn_random_data_vs_numpy():
norm_data = [ ([0., 0., 0.], [1., 1., 1.]),
([0.485 * 255, 0.456 * 255, 0.406 * 255], [0.229 * 255, 0.224 * 255, 0.225 * 255]) ]
([0.485 * 255, 0.456 * 255, 0.406 * 255], [0.229 * 255, 0.224 * 255, 0.225 * 255]),
([0.485 * 255, 0.456 * 255, 0.406 * 255], None),
([0.485 * 255, 0.456 * 255, 0.406 * 255], [255.0,]),
(None, [0.229 * 255, 0.224 * 255, 0.225 * 255]),
([128,], [0.229 * 255, 0.224 * 255, 0.225 * 255]) ]
output_layouts = {
"HWC" : ["HWC", "CHW"],
"FHWC" : ["FHWC", "FCHW"],
Expand Down

0 comments on commit 7975e9c

Please sign in to comment.