Skip to content

Commit

Permalink
Added eps and engine parameters to BN node.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Kamenev committed Feb 12, 2016
1 parent 9ad6147 commit 11fe2c6
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Examples/Image/Miscellaneous/ImageNet/ResNet/Macros.ndl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ConvBNLayerW(W, inp, outMap, kW, kH, hStride, vStride, bValue, scValue, expAvg)
isd = Parameter(outMap, 1, init = fixedValue, value = 0, needGradient = false)

c = Convolution(W, inp, kW, kH, outMap, hStride, vStride, zeroPadding = true, imageLayout = "cudnn")
y = BatchNormalization(c, sc, b, m, isd, eval = false, spatial = true, expAvgFactor = expAvg, imageLayout = "cudnn")
y = BatchNormalization(c, sc, b, m, isd, eval = false, spatial = true, expAvgFactor = expAvg, epsilon = 0.000000001, imageLayout = "cudnn")
}

ConvBNLayer(inp, outMap, inWCount, kW, kH, hStride, vStride, wScale, bValue, scValue, expAvg)
Expand Down
2 changes: 1 addition & 1 deletion Source/CNTK/BrainScript/ExperimentalNetworkBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ wstring computationNodes = // TODO: use actual TypeName() here? would first need
L"ColumnwiseCrossProduct = KhatriRaoProduct // deprecated \n" // TODO: should it be deprecated? It is described as easier to understand in the CNTKBook.
L"ClassificationError = ErrorPrediction \n"
L"Delay = PastValue \n" // TODO: should it allow negative offsets and an if test here?
L"BatchNormalization(input, scale, bias, runMean, runInvStdDev, eval, spatial, expAvgFactor, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'BatchNormalization' ; inputs = (input : scale : bias : runMean : runInvStdDev) /*plus the function args*/ ]\n"
L"BatchNormalization(input, scale, bias, runMean, runInvStdDev, eval, spatial, expAvgFactor, epsilon, engine, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'BatchNormalization' ; inputs = (input : scale : bias : runMean : runInvStdDev) /*plus the function args*/ ]\n"
// standard nodes. We use macros to define these strings.
#define UnaryStandardNode(Op, a) L## #Op L"(" L## #a L", tag='') = new ComputationNode [ operation = '" L## #Op L"' ; inputs = " L## #a L" /*plus the function args*/ ]\n"
#define BinaryStandardNode(Op, a, b) L## #Op L"(" L## #a L", " L## #b L", tag='') = new ComputationNode [ operation = '" L## #Op L"' ; inputs = (" L## #a L" : " L## #b L") /*plus the function args*/ ]\n"
Expand Down
11 changes: 10 additions & 1 deletion Source/CNTK/SynchronousExecutionEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,18 @@ void SynchronousNodeEvaluator<ElemType>::Evaluate(NDLNode<ElemType>* node, const
bool eval = node->GetOptionalParameter("eval", "false");
bool spatial = node->GetOptionalParameter("spatial", "false");
double expAvgFactor = node->GetOptionalParameter("expAvgFactor", "1.0");
double epsilon = node->GetOptionalParameter("epsilon", "0.00001");
std::wstring bnEngineS = node->GetOptionalParameter("engine", "cntk");
bool useCntkEngine;
if (bnEngineS == L"cntk")
useCntkEngine = true;
else if (bnEngineS == L"cudnn")
useCntkEngine = false;
else
InvalidArgument("Unsupported batch normalization engine, choose either \"cntk\"(default) or \"cudnn\".");
ImageLayoutKind imageLayoutKind = ImageLayoutKindFrom(node->GetOptionalParameter("imageLayout", "CHW"));

nodePtr = builder.BatchNormalization(nullptr, nullptr, nullptr, nullptr, nullptr, eval, spatial, expAvgFactor, imageLayoutKind, name);
nodePtr = builder.BatchNormalization(nullptr, nullptr, nullptr, nullptr, nullptr, eval, spatial, expAvgFactor, epsilon, useCntkEngine, imageLayoutKind, name);
}
}
else
Expand Down
4 changes: 2 additions & 2 deletions Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,9 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Looku
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::BatchNormalization(const ComputationNodePtr input,
const ComputationNodePtr scale, const ComputationNodePtr bias, const ComputationNodePtr runMean, const ComputationNodePtr runInvStdDev,
bool eval, bool spatial, double expAvgFactor, ImageLayoutKind imageLayoutKind, const std::wstring nodeName)
bool eval, bool spatial, double expAvgFactor, double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<BatchNormalizationNode<ElemType>>(net.GetDeviceId(), nodeName, eval, spatial, expAvgFactor, imageLayoutKind),
return net.AddNodeToNetAndAttachInputs(New<BatchNormalizationNode<ElemType>>(net.GetDeviceId(), nodeName, eval, spatial, expAvgFactor, epsilon, useCntkEngine, imageLayoutKind),
input, scale, bias, runMean, runInvStdDev);
}

Expand Down
3 changes: 2 additions & 1 deletion Source/ComputationNetworkLib/ComputationNetworkBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class ComputationNetworkBuilder
// The following functions create nodes and link them to the network and their inputs.
// TODO: Do we need both this set and the one above that does not add inputs? Can they share more code?
ComputationNodePtr BatchNormalization(const ComputationNodePtr input, const ComputationNodePtr scale, const ComputationNodePtr bias,
const ComputationNodePtr runMean, const ComputationNodePtr runInvStdDev, bool eval = false, bool spatial = false, double expAvgFactor = 1, ImageLayoutKind imageLayoutKind = ImageLayoutKind::CHW, const std::wstring nodeName = L"");
const ComputationNodePtr runMean, const ComputationNodePtr runInvStdDev, bool eval = false, bool spatial = false, double expAvgFactor = 1, double epsilon = 1e-5, bool useCntkEngine = true,
ImageLayoutKind imageLayoutKind = ImageLayoutKind::CHW, const std::wstring nodeName = L"");
ComputationNodePtr Convolution(const ComputationNodePtr weight,
const ComputationNodePtr inputValues,
const size_t kernelWidth, const size_t kernelHeight, const size_t outputChannels,
Expand Down
22 changes: 16 additions & 6 deletions Source/ComputationNetworkLib/TrainingNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1613,16 +1613,17 @@ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInput

public:
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name), m_eval(false), m_spatial(false), m_expAvgFactor(0), m_mbCount(0), m_imageLayoutKind(ImageLayoutKind::CHW)
: Base(deviceId, name), m_eval(false), m_spatial(false), m_expAvgFactor(0), m_epsilon(0), m_useCntkEngine(true), m_mbCount(0), m_imageLayoutKind(ImageLayoutKind::CHW)
{
}
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool eval, bool spatial, double expAvgFactor, ImageLayoutKind imageLayoutKind)
: Base(deviceId, name), m_eval(eval), m_spatial(spatial), m_expAvgFactor(expAvgFactor), m_imageLayoutKind(imageLayoutKind), m_mbCount(0)
BatchNormalizationNode(DEVICEID_TYPE deviceId, const wstring& name, bool eval, bool spatial, double expAvgFactor, double epsilon, bool useCntkEngine, ImageLayoutKind imageLayoutKind)
: Base(deviceId, name), m_eval(eval), m_spatial(spatial), m_expAvgFactor(expAvgFactor), m_epsilon(epsilon), m_useCntkEngine(useCntkEngine),
m_imageLayoutKind(imageLayoutKind), m_mbCount(0)
{
}
BatchNormalizationNode(const ScriptableObjects::IConfigRecordPtr configp)
: BatchNormalizationNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"eval"), configp->Get(L"spatial"), configp->Get(L"expAvgFactor"),
ImageLayoutKindFrom(configp->Get(L"imageLayout")))
configp->Get(L"epsilon"), configp->Get(L"useCntkEngine"), ImageLayoutKindFrom(configp->Get(L"imageLayout")))
{
AttachInputs(configp, this->GetExpectedNumInputs());
}
Expand Down Expand Up @@ -1762,7 +1763,7 @@ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInput
m_saveInvStdDev->Resize(runMean.GetNumRows(), runMean.GetNumCols());

m_convEng->NormalizeBatch(*m_inT, sliceInputValue, *m_scaleBiasT, scale, bias, m_spatial, expAvgFactor, runMean, runInvStdDev,
sliceOutputValue, 1e-5, *m_saveMean, *m_saveInvStdDev);
sliceOutputValue, m_epsilon, *m_saveMean, *m_saveInvStdDev);

m_mbCount++;
}
Expand All @@ -1784,12 +1785,17 @@ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInput

if (isFinalValidationPass)
{
if (m_spatial && m_imageLayoutKind != CHW)
{
InvalidArgument("Batch normalization currently supports only cuDNN (CHW) format. Please specify imageLayout=\"cudnn\" in BatchNormalization node in your NDL/BrainScript.");
}

auto shape = GetSampleLayout();

if (m_factory == nullptr)
m_factory = ConvolutionEngineFactory<ElemType>::Create(m_deviceId, ConvolutionEngineFactory<ElemType>::EngineType::Auto, m_imageLayoutKind);
if (m_convEng == nullptr)
m_convEng = m_factory->CreateConvEngine(m_deviceId, 0);
m_convEng = m_factory->CreateConvEngine(m_deviceId, 0, m_useCntkEngine ? BatchNormImpl::Cntk : BatchNormImpl::CuDnn);
if (m_spatial)
{
auto dims = ImageDimensions(shape, m_imageLayoutKind);
Expand Down Expand Up @@ -1872,6 +1878,10 @@ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInput
bool m_spatial;
// Smoothing factor.
double m_expAvgFactor;
// Epsilon used to compute inverse std deviation.
double m_epsilon;
// Whether to use CNTK or cuDNN BN implementation.
bool m_useCntkEngine;
// Layout (e.g. CHW).
ImageLayoutKind m_imageLayoutKind;
// Minibatch count, used to compute cumulative moving average.
Expand Down
2 changes: 1 addition & 1 deletion Tests/UnitTests/MathTests/ConvolutionEngineTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ static bool CheckEqual(const Matrix<T>& result, const Matrix<T>& reference, std:
float b = ref[badIndex];
std::stringstream ss;
ss << count << " mismatch" << (count > 1 ? "es" : "") << ", first mismatch at " << badIndex << ", " << a << " != " << b
<< ", rel = " << std::abs(a - b) << ", abs = " << (std::abs(a - b) / std::max(std::abs(a), std::abs(b)));
<< ", rel = " << (std::abs(a - b) / std::max(std::abs(a), std::abs(b))) << ", abs = " << std::abs(a - b);
msg = ss.str();
}
return count == 0;
Expand Down

0 comments on commit 11fe2c6

Please sign in to comment.