Move scalar tests from common_nn to legacy_nn. (#5223)
gchanan authored and colesbury committed Feb 13, 2018
1 parent f96f3c3 commit 232530c
Showing 2 changed files with 172 additions and 170 deletions.
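Note: the entries removed below are plain test-specification dicts. Each names an nn module, optionally supplies constructor_args, gives an input_size (where () means a 0-dimensional "scalar" tensor), and usually a reference_fn to check outputs against. The following is a minimal, hypothetical sketch of how one such dict could be exercised; the run_module_test helper is illustrative only and is not the harness actually used by common_nn.py / legacy_nn.py. The example dict is one of the removed Hardtanh scalar entries.

import torch
import torch.nn as nn

# Hypothetical illustration only -- not the actual test harness.
# The "scalar" variants being moved use input_size=(), i.e. a 0-dim tensor.
example_test = dict(
    module_name='Hardtanh',
    input_size=(),
    reference_fn=lambda i, _: i.clamp(-1, 1),
    desc='scalar',
)

def run_module_test(test):
    # Instantiate the module from its name and optional constructor args.
    module = getattr(nn, test['module_name'])(*test.get('constructor_args', ()))
    # input_size=() yields a 0-dimensional tensor, which is what the scalar tests cover.
    inp = torch.randn(test['input_size'], requires_grad=True)
    out = module(inp)
    if 'reference_fn' in test:
        expected = test['reference_fn'](inp.detach(), None)
        assert torch.allclose(out.detach(), expected)
    return out

run_module_test(example_test)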
169 changes: 0 additions & 169 deletions test/common_nn.py
@@ -55,13 +55,6 @@ def get_weight(m):
check_inplace=True,
desc='threshold_value'
),
dict(
module_name='Threshold',
constructor_args=(2, 1),
input_size=(),
check_inplace=True,
desc='threshold_value_scalar'
),
dict(
module_name='Threshold',
constructor_args=(2, 10),
@@ -73,23 +66,11 @@ def get_weight(m):
input_size=(2, 3, 4, 5),
check_inplace=True,
),
dict(
module_name='ReLU',
input_size=(),
check_inplace=True,
desc='scalar'
),
dict(
module_name='ReLU6',
input_size=(2, 3, 4, 5),
check_inplace=True,
),
dict(
module_name='ReLU6',
input_size=(),
check_inplace=True,
desc='scalar'
),
dict(
module_name='RReLU',
input_size=(1, 2, 2),
@@ -102,55 +83,25 @@ def get_weight(m):
desc='with_up_down',
test_cuda=False,
),
dict(
module_name='RReLU',
constructor_args=(0.1, 0.9),
input_size=(),
desc='with_up_down_scalar',
test_cuda=False,
),
dict(
module_name='Hardtanh',
input_size=(3, 2, 5),
reference_fn=lambda i, _: i.clamp(-1, 1),
),
dict(
module_name='Hardtanh',
input_size=(),
reference_fn=lambda i, _: i.clamp(-1, 1),
desc='scalar'
),
dict(
module_name='Sigmoid',
input_size=(2, 3, 4, 5)
),
dict(
module_name='Sigmoid',
input_size=(),
desc='scalar',
),
dict(
module_name='Tanh',
input_size=(2, 3, 4, 5)
),
dict(
module_name='Tanh',
input_size=(),
desc='scalar',
),
dict(
module_name='Softmax',
constructor_args=(1,),
input_size=(10, 20),
reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)),
),
dict(
module_name='Softmax',
constructor_args=(0,),
input_size=(),
reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(0, True)),
desc='scalar',
),
dict(
module_name='Softmax2d',
input_size=(1, 3, 10, 20),
@@ -169,36 +120,17 @@ def get_weight(m):
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
desc='multiparam',
),
dict(
module_name='LogSoftmax',
constructor_args=(0,),
input_size=(),
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
desc='multiparam_scalar',
),
dict(
module_name='ELU',
constructor_args=(2.,),
input_size=(3, 2, 5),
),
dict(
module_name='ELU',
constructor_args=(2.,),
input_size=(),
desc='scalar',
),
# TODO: reference function
dict(
module_name='Hardshrink',
constructor_args=(2.,),
input_size=(4, 3, 2, 4),
),
dict(
module_name='Hardshrink',
constructor_args=(2.,),
input_size=(),
desc='scalar',
),
dict(
module_name='LeakyReLU',
input_size=(3, 2, 5),
@@ -211,24 +143,11 @@ def get_weight(m):
check_inplace=True,
desc='with_negval'
),
dict(
module_name='LeakyReLU',
constructor_args=(0.5,),
input_size=(),
check_inplace=True,
desc='with_negval_scalar'
),
dict(
module_name='LogSigmoid',
input_size=(2, 3, 4),
reference_fn=lambda i, _: i.sigmoid().log(),
),
dict(
module_name='LogSigmoid',
input_size=(),
reference_fn=lambda i, _: i.sigmoid().log(),
desc='scalar'
),
dict(
module_name='Softplus',
input_size=(10, 20),
@@ -249,14 +168,6 @@ def get_weight(m):
((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
desc='beta_threshold',
),
dict(
module_name='Softplus',
constructor_args=(2, -100),
input_size=(),
reference_fn=(lambda i, _: ((i * 2) > -100).type_as(i) * i +
((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
desc='beta_threshold_scalar',
),
dict(
module_name='Softshrink',
input_size=(3, 2, 5),
@@ -267,12 +178,6 @@ def get_weight(m):
input_size=(3, 2, 5),
desc='lambda',
),
dict(
module_name='Softshrink',
constructor_args=(1,),
input_size=(),
desc='lambda_scalar',
),
dict(
module_name='CrossMapLRN2d',
constructor_args=(5, 5e-3, 1e-3, 2),
@@ -292,12 +197,6 @@ def get_weight(m):
desc='1d_multiparam',
reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
),
dict(
module_name='PReLU',
input_size=(),
reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
desc='scalar',
),
dict(
module_name='PReLU',
input_size=(2, 3, 4, 5),
@@ -329,12 +228,6 @@ def get_weight(m):
input_size=(3, 2, 5),
reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
),
dict(
module_name='Softsign',
input_size=(),
reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
desc='scalar',
),
dict(
module_name='Softmin',
constructor_args=(1,),
@@ -346,21 +239,10 @@ def get_weight(m):
input_size=(2, 3, 5, 10),
desc='multidim',
),
dict(
module_name='Softmin',
constructor_args=(0,),
input_size=(),
desc='scalar',
),
dict(
module_name='Tanhshrink',
input_size=(2, 3, 4, 5)
),
dict(
module_name='Tanhshrink',
input_size=(),
desc='scalar',
),
]


@@ -547,13 +429,6 @@ def torch_randn(sizes, requires_grad=False):
reference_fn=lambda i, t, _: 1. / i.numel() *
sum((a - b).abs().sum() for a, b in zip(i, t)),
),
dict(
module_name='L1Loss',
input_size=(),
target_size=(),
reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(),
desc='scalar',
),
dict(
module_name='NLLLoss',
input_fn=lambda: torch.rand(15, 10).log(),
@@ -605,30 +480,13 @@ def torch_randn(sizes, requires_grad=False):
kldivloss_reference(i, t, get_size_average(m), reduce=True),
check_no_size_average=True,
),
dict(
module_name='KLDivLoss',
input_fn=lambda: torch_rand(()).log(),
target_fn=lambda: torch_rand(()),
reference_fn=lambda i, t, m:
kldivloss_reference(i, t, get_size_average(m), reduce=True),
check_no_size_average=True,
desc='scalar',
),
dict(
module_name='MSELoss',
input_size=(2, 3, 4, 5),
target_size=(2, 3, 4, 5),
reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1),
check_no_size_average=True,
),
dict(
module_name='MSELoss',
input_size=(),
target_size=(),
reference_fn=lambda i, t, m: (i - t).abs().pow(2).sum() / (i.numel() if get_size_average(m) else 1),
check_no_size_average=True,
desc='scalar'
),
dict(
module_name='BCELoss',
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
@@ -647,16 +505,6 @@ def torch_randn(sizes, requires_grad=False):
desc='weights',
check_gradgrad=False,
),
dict(
module_name='BCELoss',
constructor_args_fn=lambda: (torch_rand(()),),
input_fn=lambda: torch_rand(()).clamp_(1e-2, 1 - 1e-2),
target_fn=lambda: torch_rand(()).gt(0).double(),
reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
(i.numel() if get_size_average(m) else 1),
desc='scalar_weights',
check_gradgrad=False,
),
dict(
module_name='CrossEntropyLoss',
input_size=(15, 10),
@@ -687,14 +535,6 @@ def torch_randn(sizes, requires_grad=False):
desc='margin',
check_no_size_average=True,
),
dict(
module_name='HingeEmbeddingLoss',
constructor_args=(0.5,),
input_size=(),
target_fn=lambda: torch_randn(()).gt(0).double().mul_(2).sub(1),
desc='scalar_margin',
check_no_size_average=True,
),
dict(
module_name='MultiLabelMarginLoss',
input_size=(10,),
@@ -742,15 +582,6 @@ def torch_randn(sizes, requires_grad=False):
reference_fn=lambda i, t, m:
smoothl1loss_reference(i, t, size_average=get_size_average(m)),
),
dict(
module_name='SmoothL1Loss',
input_size=(),
target_size=(),
check_no_size_average=True,
reference_fn=lambda i, t, m:
smoothl1loss_reference(i, t, size_average=get_size_average(m)),
desc='scalar',
),
dict(
module_name='SoftMarginLoss',
input_size=(5, 5),
(Second changed file, test/legacy_nn.py, not shown in this view: 172 additions & 1 deletion.)
