added new interface IFreezable to tell a node to freeze itself, in order to allow BatchNormalization to honor CloneFunction (..., parameters="constant")
frankseide committed Jul 22, 2016
1 parent ce350dd commit 0270010
Showing 6 changed files with 72 additions and 41 deletions.
6 changes: 3 additions & 3 deletions Source/ComputationNetworkLib/ComputationNetworkScripting.cpp
@@ -511,7 +511,7 @@ class CloneFunctionConfigLambda : public ConfigLambda
for (let& node : allInputs)
{
// add parameters that are to be cloned to dependent set
- if (parameterTreatment != ParameterTreatment::shared && node->Is<IParameterNode>())
+ if (parameterTreatment != ParameterTreatment::shared && node->Is<IFreezable>())
dependentSet.insert(node);
// if at least one input is in the dependent set then this node is, too
else
@@ -603,8 +603,8 @@ class CloneFunctionConfigLambda : public ConfigLambda
let newName = exprName + L"." + node->GetName();
newNode = node->Duplicate(newName, CopyNodeFlags::copyNodeAll);
// make it read-only if desired
- if (parameterTreatment == ParameterTreatment::constant)
- newNode->SetLearningRateMultiplier(0);
+ if (parameterTreatment == ParameterTreatment::constant && newNode->Is<IFreezable>())
+ newNode->As<IFreezable>()->FreezeParameters();
// and that's our cloned node
clonedNodes[node] = newNode;
numCloned++;
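For context, a minimal self-contained sketch of the dispatch pattern the hunk above relies on. The class names and the FreezeClonedNodes helper are illustrative, not CNTK's actual hierarchy: the point is that CloneFunction (..., parameters="constant") only needs to know about IFreezable, and each node type decides for itself what "freeze" means.

#include <limits>
#include <memory>
#include <vector>

struct Node { virtual ~Node() = default; }; // stand-in for the real node base class

struct IFreezable { virtual void FreezeParameters() { } };

struct ParameterNode : Node, IFreezable
{
    double m_learningRateMultiplier = 1.0;
    void FreezeParameters() override { m_learningRateMultiplier = 0; } // SGD will skip this node
};

struct BatchNormNode : Node, IFreezable
{
    double m_normTimeConst = 0, m_blendTimeConst = 0;
    void FreezeParameters() override // pin the node to its running statistics
    {
        m_normTimeConst = std::numeric_limits<double>::infinity();
        m_blendTimeConst = std::numeric_limits<double>::infinity();
    }
};

void FreezeClonedNodes(std::vector<std::shared_ptr<Node>>& clonedNodes)
{
    for (auto& node : clonedNodes)
        if (auto* f = dynamic_cast<IFreezable*>(node.get())) // CNTK uses its Is<>/As<> helpers instead
            f->FreezeParameters();
}

This is why the dependent-set test above switches from IParameterNode to IFreezable: BatchNormalization is not a parameter node, but it still carries state that must be frozen when cloned as a constant.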
6 changes: 3 additions & 3 deletions Source/ComputationNetworkLib/ComputationNode.h
@@ -1897,11 +1897,11 @@ class LateAttachingNode : public N, public ILateAttachingNode
struct IRecurrentNode { virtual int GetRecurrenceSteppingDirection() const = 0; };

// =======================================================================
- // IParameterNode -- interface implemented by ComputationNodes that are parameters
- // Note: There is possibly code that identifies parameters by the type name instead. Should be unified.
+ // IFreezable -- nodes that have parameters that can be frozen,
+ // e.g. if a trained model is to be used as a fixed feature extractor for another task
// =======================================================================

- struct IParameterNode { virtual ~IParameterNode() { } };
+ struct IFreezable { virtual void FreezeParameters() { } };

// =======================================================================
// PreComputedNodeBase -- interface implemented by ComputationNodes that precompute
6 changes: 6 additions & 0 deletions Source/ComputationNetworkLib/InputAndParamNodes.cpp
@@ -286,6 +286,12 @@ template <class ElemType>
PrintNodeValuesToFile(printValues, printMetadata, fstream);
}

+ template <class ElemType>
+ /*virtual*/ void LearnableParameter<ElemType>::FreezeParameters() /*override*/ // from IFreezable
+ {
+ SetLearningRateMultiplier(0);
+ }

template class LearnableParameter<float>;
template class LearnableParameter<double>;

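A quick sketch of why a learning-rate multiplier of 0 freezes a parameter (hypothetical update loop, not CNTK's actual SGD code): the per-node multiplier scales the global learning rate, so a value of 0 turns the gradient step into a no-op while the parameter values stay in place.

#include <cstddef>
#include <vector>

void SgdUpdate(std::vector<double>& value, const std::vector<double>& gradient,
               double learningRate, double learningRateMultiplier)
{
    const double effectiveRate = learningRate * learningRateMultiplier; // 0 when frozen
    for (std::size_t i = 0; i < value.size(); i++)
        value[i] -= effectiveRate * gradient[i]; // no-op when effectiveRate == 0
}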
5 changes: 4 additions & 1 deletion Source/ComputationNetworkLib/InputAndParamNodes.h
@@ -21,7 +21,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------

template <class ElemType>
- class LearnableParameter : public ComputationNode<ElemType>, public NumInputs<0>, public IParameterNode
+ class LearnableParameter : public ComputationNode<ElemType>, public NumInputs<0>, public IFreezable
{
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"LearnableParameter"; }
@@ -106,6 +106,9 @@ class LearnableParameter : public ComputationNode<ElemType>, public NumInputs<0>
void InferInputDimsFrom(const TensorShape& otherShape);

virtual void DumpNodeInfo(const bool printValues, const bool printMetadata, File& fstream) const override;

+ // called from CloneFunction(..., parameters="constant")
+ virtual void FreezeParameters() override; // from IFreezable
};

// -----------------------------------------------------------------------
88 changes: 55 additions & 33 deletions Source/ComputationNetworkLib/TrainingNodes.h
@@ -1534,8 +1534,8 @@ template class DropoutNode<float>;
template class DropoutNode<double>;

// -----------------------------------------------------------------------
- // BatchNormalizationNode (input, scale, bias, runMean, runInvStdDev, spatial,
- // normalizationTimeConstant = 0, blendTimeConstant = 0,
+ // BatchNormalizationNode (input, scale, bias, runMean, runInvStdDev,
+ // spatial, normalizationTimeConstant = 0, blendTimeConstant = 0,
// epsilon = 0.00001,
// useCntkEngine = true, imageLayout = 'cudnn')
//
@@ -1573,7 +1573,7 @@ template class DropoutNode<double>;
// * imageLayout is the image layout. Only cudnn is supported at present.
// -----------------------------------------------------------------------
template <class ElemType>
- class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInputs<5>
+ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInputs<5>, public IFreezable
{
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"BatchNormalization"; }
@@ -1745,52 +1745,50 @@ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInput

Matrix<ElemType> sliceOutputValue = ValueFor(fr);

- // are we training or in inference mode?
- // In inference mode, running estimates are used as the sole estimates, while the MB values are not used at all.
- if (Input(3)->IsParameterUpdateRequired() ^ Input(4)->IsParameterUpdateRequired())
- InvalidArgument("BatchNormalization: Either both or none of %ls and %ls must be enabled for model update.",
- Input(3)->NodeDescription().c_str(), Input(4)->NodeDescription().c_str());
- bool inferenceMode =
- !Environment().IsTraining() || // we are actually inferring
- !Input(3)->IsParameterUpdateRequired(); // we are training, but this piece of network has been frozen (e.g. as a fixed feature extractor)

// determine the factors from the time constants
- double expAvgFactor;
- double blendFactor;
- if (inferenceMode) // in inference mode, only use long-term mean and do not update running estimates
+ double expAvgFactor; // weight for the new MB statistics in the running estimate. The previous value of the running statistics is kept with weight (1-this)
+ double blendFactor; // interpolation weight for the running statistics (the current MB statistics are weighted with 1-this)
+ if (!Environment().IsTraining()) // in inference mode, only use long-term mean and do not update running estimates
{
- expAvgFactor = 0; // no new contribution from current minibatch
- blendFactor = 1.0; // estimate is taken 100% from the long-term running estimate
+ expAvgFactor = 0; // (m_normTimeConst == infinity) no new contribution from current minibatch
+ blendFactor = 1.0; // (m_blendTimeConst == infinity) estimate is taken 100% from the long-term running estimate
}
else
{
+ // (both time constants have edge cases of 0 and infinity which are special-cased below for numerical reasons)
double numSamples = (double)GetMBLayout()->GetActualNumSamples();
if (m_normTimeConst > 0)
{
- // Convert to per-minibatch factor. Treat positive infinity as if running mean/var parameters are "frozen",
- // that is, do not require updates.
- expAvgFactor = !isfinite(m_normTimeConst) ? 0 : (1.0 - exp(-numSamples / m_normTimeConst));
+ expAvgFactor = isfinite(m_normTimeConst)
+ ? (1.0 - exp(-numSamples / m_normTimeConst))
+ : 0; // (same; special-cased for numerical reasons only)
}
else
{
// REVIEW alexeyk: hack, m_normTimeConst < 0 is used to compute CMA.
- expAvgFactor = (m_normTimeConst < 0) ? (1.0 / (1.0 + m_mbCount)) : 1.0;
+ expAvgFactor = (m_normTimeConst < 0)
+ ? (1.0 / (1.0 + m_mbCount)) // (this is the hack case)
+ : 1.0; // (same as 'then' branch above; special-cased for numerical reasons only)
}

- if (!isfinite(m_blendTimeConst))
- blendFactor = 1.0;
+ if (isfinite(m_blendTimeConst))
+ blendFactor = m_blendTimeConst > 0
+ ? (m_blendTimeConst / (m_blendTimeConst + numSamples)) // interpolate
+ : 0; // (same; special-casing for 0 only for numerical reasons)
else
- blendFactor = m_blendTimeConst > 0 ? (m_blendTimeConst / (m_blendTimeConst + numSamples)) : 0;
+ blendFactor = 1.0; // (same; special-casing for 0 only for numerical reasons)
}

// TODO: These Resize() operations belong INSIDE Forward().
// Specifically, for blendFactor=1, they must come back resized to (0,0). This is how Backward() will know & use running ones instead.
// I am not fixing this now because I don't know how to identify all variants of Forward(), across engines, CPU/GPU etc.
if (blendFactor == 1.0)
- fprintf(stderr, "WARNING WARNING WARNING: blendFactor=1\n")
+ fprintf(stderr, "WARNING WARNING WARNING: blendFactor=1\n");
if (blendFactor == 1.0
- #if 1 // otherwise this crashes--due to cuDNN?
- && inferenceMode
+ #if 1 // otherwise this crashes--seems cuDNN still needs these?
+ && !Environment().IsTraining()
#endif
)
{
@@ -1880,6 +1878,14 @@ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInput
m_blendTimeConst = blendTimeConstant;
}

+ // called from CloneFunction(..., parameters="constant")
+ // Once called, this node is put into inference mode.
+ virtual void FreezeParameters() override // from IFreezable
+ {
+ m_normTimeConst = std::numeric_limits<double>::infinity();
+ m_blendTimeConst = std::numeric_limits<double>::infinity();
+ }

private:
// Old versioning - do not use. Do not remove until we're sure there are no old models around.
struct VersionInfo
@@ -1894,34 +1900,50 @@ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInput
VersionInfo m_version;

private:
+ // --- configuration parameters

// Determines whether to use per-activation (used after non-convolutional layers like fully connected)
// or spatial (used after convolutional layers).
// TODO: This should not be a config option, but rather inferred from dimensions of the Parameters.
bool m_spatial;
- // Time constant for running mean and variance.

+ // Time constant for estimating the running mean and variance.
+ // This is the time constant of a low-pass filter.
+ // If 0, running mean and variance just remember the last minibatch.
+ // If infinity, running mean and variance are not updated, like in inference mode.
double m_normTimeConst;
- // Time constant for blending running mean/var and current minibatch mean/var.
- // The main idea is to represent current minibatch statistics as MAP estimate, linear interpolation
- // of smoothed and minibatch statistics.

+ // Equivalent sample count for blending running mean/var and current minibatch mean/var.
+ // Roughly, this specifies how many samples the running statistics are "worth",
+ // relative to the current minibatch statistics.
+ // If 0, only use the current MB statistics. If infinity, use only the running mean, like in inference mode.
+ // The main idea is to estimate the mean/variance as a MAP estimate, using the running mean/var as a prior.
+ // This should make the method more robust to the case of very small minibatches,
+ // and also provides a meaningful interpretation of inference mode, where only the prior is used.
+ // Effectively, this ends up as a linear interpolation of running and minibatch statistics.
+ // The idea is due to Frank Seide et al.
- // It should also work well in data parallelism scenario
- // as opposed to plain vanilla BN implementation which would require aggregation of statistics
- // from all nodes.
+ // It should also work well in data-parallel scenarios, as opposed to a plain vanilla BN implementation,
+ // which would require aggregation of statistics from all nodes.
// REVIEW alexeyk: if this works, document it properly in Wiki.
double m_blendTimeConst;

// Epsilon used to compute inverse std deviation.
double m_epsilon;
// Whether to use CNTK or cuDNN BN implementation.
bool m_useCntkEngine;
// Layout (e.g. CHW).
ImageLayoutKind m_imageLayoutKind;

+ // --- working variables

// Minibatch count, used to compute cumulative moving average.
size_t m_mbCount;

// Interpolated actual mean/stddev values. Pre-computed on forward pass, also used in gradient computation.
shared_ptr<Matrix<ElemType>> m_saveMean;
shared_ptr<Matrix<ElemType>> m_saveInvStdDev;
// Temp buffer for scale and bias derivatives. Only used in BackpropTo(), carrying info from first call to subsequent calls.
- // Not used for blendFactor=1.
+ // Not used for blendFactor=1 in CNTK engine.
shared_ptr<Matrix<ElemType>> m_dScale;
shared_ptr<Matrix<ElemType>> m_dBias;

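To make the time-constant semantics concrete: a scalar sketch (an assumed simplification — names are illustrative, the real code operates on matrices and also special-cases the negative-normTimeConst CMA hack, which is omitted here) of the factors computed in the forward pass above, and of why FreezeParameters() setting both constants to infinity yields pure inference-mode behavior.

#include <cmath>

struct BnFactors { double expAvgFactor, blendFactor; };

BnFactors ComputeFactors(double normTimeConst, double blendTimeConst, double numSamples)
{
    BnFactors f;
    // Weight of the current minibatch in the running-estimate update;
    // infinity means the running statistics are never updated.
    f.expAvgFactor = (std::isfinite(normTimeConst) && normTimeConst > 0)
        ? 1.0 - std::exp(-numSamples / normTimeConst)
        : (normTimeConst == 0 ? 1.0 : 0.0); // 0: remember last MB only; infinity: frozen
    // Weight of the running estimate in the statistics used for normalization:
    // usedMean = blendFactor * runningMean + (1 - blendFactor) * mbMean.
    f.blendFactor = std::isfinite(blendTimeConst)
        ? blendTimeConst / (blendTimeConst + numSamples)
        : 1.0; // infinity: normalize with the running estimate only
    return f;
}

With normTimeConst = blendTimeConst = infinity, as set by FreezeParameters(), this gives expAvgFactor = 0 and blendFactor = 1: the running estimates are used unchanged and never updated, exactly the inference-mode path above.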
2 changes: 1 addition & 1 deletion Source/Math/GPUMatrix.cu
@@ -3155,7 +3155,7 @@ void GPUMatrix<ElemType>::BatchNormalizationForward(const GPUMatrix<ElemType>& s
// REVIEW alexeyk: can be rolled into NormalizeBatchTraining to save bandwidth.
// TODO: add a 'beta' parameter to ScaleAndAdd()
Scale((ElemType)(1 - blendFactor), saveMean);
- ScaleAndAdd((ElemType)blendFactor, runMean, saveMean);
+ ScaleAndAdd((ElemType)blendFactor, /*in*/ runMean, /*in/out*/ saveMean);
Scale((ElemType)(1 - blendFactor), saveInvStdDev);
ScaleAndAdd((ElemType)blendFactor, runInvStdDev, saveInvStdDev);
}
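Spelled out in scalar form (illustrative only), the Scale/ScaleAndAdd pair above is an in-place linear interpolation of the saved minibatch statistics toward the running estimate:

// What the two calls compute for each element:
//   saveMean  = (1 - blendFactor) * saveMean;   // Scale
//   saveMean += blendFactor * runMean;          // ScaleAndAdd
double BlendInPlace(double saveMean, double runMean, double blendFactor)
{
    saveMean = (1.0 - blendFactor) * saveMean;
    saveMean += blendFactor * runMean;
    return saveMean; // == blendFactor * runMean + (1 - blendFactor) * saveMean
}

For blendFactor = 1 (inference mode) this returns runMean unchanged, which is why the /*in*/ and /*in/out*/ annotations above matter: saveMean is both an input and the output.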
