Skip to content

Commit

Permalink
Adding full support for 'alpha' attribute in ELU op in ONNX and CNTK.
Browse files Browse the repository at this point in the history
  • Loading branch information
Spandan Tiwari committed Jun 28, 2018
1 parent c589619 commit 69938f2
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 8 deletions.
5 changes: 5 additions & 0 deletions Source/CNTKv2LibraryDll/API/CNTKLibrary.h
Original file line number Diff line number Diff line change
Expand Up @@ -4759,6 +4759,11 @@ namespace CNTK
///
CNTK_API FunctionPtr ELU(const Variable& operand, const std::wstring& name = L"");

///
/// Create an instance of the CNTK built-in elementwise exponential linear unit operation with the specified alpha and input operand.
///
CNTK_API FunctionPtr ELU(const Variable& operand, double alpha, const std::wstring& name = L"");

///
/// Create an instance of the CNTK built-in elementwise scaled exponential linear unit operation with the specified input operand.
///
Expand Down
13 changes: 13 additions & 0 deletions Source/CNTKv2LibraryDll/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2749,6 +2749,19 @@ namespace CNTK
return UnaryOp(PrimitiveOpType::ELU, operand, Dictionary(), name);
}

FunctionPtr ELU(const Variable& operand, double alpha, const std::wstring& name)
{
    // Record alpha in the block's attributes so exporters (e.g. ONNX) can
    // recover it later without re-deriving it from the composed graph.
    Dictionary attributes;
    attributes[PrimitiveFunction::AttributeNameAlpha] = alpha;

    // Build the composite as a block function over a placeholder:
    //   ELU_alpha(x) = alpha * (exp(x) - 1)  if x < 0
    //               = x                      otherwise
    // The negative branch reuses the unit-alpha ELU primitive and scales it.
    auto placeholder = PlaceholderVariable();
    const auto dataType = operand.GetDataType();
    auto zero = Constant::Scalar(dataType, 0.0);
    auto alphaScalar = Constant::Scalar(dataType, alpha);

    auto isNegative = Less(placeholder, zero);
    auto negativeBranch = ElementTimes(alphaScalar, ELU(placeholder, name + L"_ELU"));
    auto composed = ElementSelect(isNegative, negativeBranch, placeholder);

    return AsBlock(std::move(composed), { { placeholder, operand } }, std::move(attributes), L"ELU", name);
}

FunctionPtr SELU(const Variable& operand, double gamma, double alpha, const std::wstring& name)
{
auto additionalProperties = Dictionary();
Expand Down
4 changes: 3 additions & 1 deletion Source/CNTKv2LibraryDll/proto/onnx/CNTKToONNX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2882,7 +2882,9 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod
}
else if (src->OpName() == L"ELU")
{
auto alpha = 1.0f;
float alpha = 1.0f;
if (src->Attributes().Contains(L"alpha"))
alpha = (float)src->Attributes()[L"alpha"].Value<double>();
node->AddAttribute("alpha", alpha);
}
else if (src->OpName() == L"LeakyReLU")
Expand Down
3 changes: 2 additions & 1 deletion Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2306,7 +2306,8 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
}
else if (onnxOpName == "Elu")
{
FunctionPtr cntkFunction = ELU(inputs[0], ToWString(node->Name()));
double alpha = static_cast<double>(GetNamedAttributeAsFloat(node, "alpha", 1.0f));
FunctionPtr cntkFunction = ELU(inputs[0], alpha, ToWString(node->Name()));
return cntkFunction;
}
else if (onnxOpName == "Exp")
Expand Down
3 changes: 2 additions & 1 deletion Source/CNTKv2LibraryDll/proto/onnx/Operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ namespace ONNX
} } },
{ L"ELU", { {
{ L"ELU", "Elu" },
// { L"", "alpha" },
{ L"alpha", "alpha" },
} } },
{ L"Exp", { {
{ L"Exp", "Exp" },
Expand Down Expand Up @@ -465,6 +465,7 @@ namespace ONNX
}
std::unordered_map<std::wstring, std::set<size_t>> Operators::_cntkBlockOPInvalidIndices = {
{ L"Clip",{ 1, 2 } },
{ L"ELU",{ 0, 1 } },
{ L"LeakyReLU",{ 0, 1 } },
{ L"SELU",{ 0, 1, 2 } },
{ L"PReLU",{ 0 } },
Expand Down
6 changes: 3 additions & 3 deletions bindings/python/cntk/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,10 +1363,10 @@ def relu(x, name=''):


@typemap
def elu(x, name=''):
def elu(x, alpha=1.0, name=''):
'''
Exponential linear unit operation. Computes the element-wise exponential linear
of ``x``: ``max(x, 0)`` for ``x >= 0`` and ``x``: ``exp(x)-1`` otherwise.
of ``x``: ``max(x, 0)`` for ``x >= 0`` and ``x``: ``alpha * (exp(x)-1)`` otherwise.
The output tensor has the same shape as ``x``.
Expand All @@ -1384,7 +1384,7 @@ def elu(x, name=''):
'''
from cntk.cntk_py import elu
x = sanitize_input(x)
return elu(x, name)
return elu(x, alpha, name)

@typemap
def selu(x, scale=1.0507009873554804934193349852946, alpha=1.6732632423543772848170429916717, name=''):
Expand Down
8 changes: 6 additions & 2 deletions bindings/python/cntk/tests/onnx_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,14 @@ def test_Elu(tmpdir, dtype):
model = C.elu(data)
verify_no_input(model, tmpdir, 'Elu_0')

x = C.input_variable(data.shape)
model = C.elu(x)
x1 = C.input_variable(data.shape)
model = C.elu(x1)
verify_one_input(model, data, tmpdir, 'Elu_1')

x2 = C.input_variable(data.shape)
model = C.elu(x2, alpha=2.0)
verify_one_input(model, data, tmpdir, 'Elu_2')

#Equal
@pytest.mark.parametrize("dtype", DType_Config)
def test_Equal(tmpdir, dtype):
Expand Down

0 comments on commit 69938f2

Please sign in to comment.