Skip to content

Commit

Permalink
regularized the naming convention in SimpleNetworkBuilder.cpp/h, and …
Browse files Browse the repository at this point in the history
…added comments from Kaisheng on network kinds
  • Loading branch information
frankseide committed Jan 22, 2016
1 parent 55b3ae9 commit 198d64d
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 63 deletions.
46 changes: 23 additions & 23 deletions Source/CNTK/SimpleNetworkBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,42 +24,42 @@ template <class ElemType>
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildNetworkFromDescription()
{
ComputationNetworkPtr net;
switch (m_rnnType)
switch (m_standardNetworkKind)
{
case SIMPLENET:
net = BuildSimpleDNNFromDescription();
case FFDNNKind:
net = BuildFFDNNFromDescription();
break;
case SIMPLERNN:
net = BuildSimpleRNNFromDescription();
case RNNKind:
net = BuildRNNFromDescription();
break;
case LSTM:
case LSTMKind:
net = BuildLSTMNetworkFromDescription();
break;
case CLASSLSTM:
case ClassLSTMNetworkKind:
net = BuildClassLSTMNetworkFromDescription();
break;
case NCELSTM:
case NCELSTMNetworkKind:
net = BuildNCELSTMNetworkFromDescription();
break;
case CLASSLM:
net = BuildClassEntropyNetworkFromDescription();
case ClassEntropyRNNKind:
net = BuildClassEntropyRNNFromDescription();
break;
case LBLM:
case LogBilinearNetworkKind:
net = BuildLogBilinearNetworkFromDescription();
break;
case NPLM:
net = BuildNeuralProbNetworkFromDescription();
case DNNLMNetworkKind:
net = BuildDNNLMNetworkFromDescription();
break;
case CLSTM:
case ConditionalLSTMNetworkKind:
net = BuildConditionalLSTMNetworkFromDescription();
break;
#ifdef COMING_SOON
case RCRF:
net = BuildSeqTrnLSTMNetworkFromDescription();
case CRFLSTMNetworkKind:
net = BuildCRFLSTMNetworkFromDescription();
break;
#endif
default:
LogicError("BuildNetworkFromDescription: invalid m_rnnType %d", (int) m_rnnType);
LogicError("BuildNetworkFromDescription: invalid m_standardNetworkKind %d", (int) m_standardNetworkKind);
}

// post-process the network
Expand All @@ -69,7 +69,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildNetworkFromDescriptio
}

template <class ElemType>
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildSimpleDNNFromDescription()
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildFFDNNFromDescription()
{

ComputationNetworkBuilder<ElemType> builder(*m_net);
Expand Down Expand Up @@ -168,7 +168,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildSimpleDNNFromDescript

// Note: while ComputationNode and ComputationNetwork are (supposed to be) independent of ElemType, it is OK to keep this class dependent.
template <class ElemType>
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildSimpleRNNFromDescription()
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildRNNFromDescription()
{
ComputationNetworkBuilder<ElemType> builder(*m_net);
if (m_net->GetTotalNumberOfNodes() < 1) // not built yet
Expand Down Expand Up @@ -276,7 +276,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildSimpleRNNFromDescript
}

template <class ElemType>
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildClassEntropyNetworkFromDescription()
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildClassEntropyRNNFromDescription()
{
ComputationNetworkBuilder<ElemType> builder(*m_net);

Expand All @@ -292,7 +292,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildClassEntropyNetworkFr
ComputationNodePtr wrd2cls, cls2idx, clslogpostprob, clsweight;

if (m_vocabSize != m_layerSizes[numHiddenLayers + 1])
RuntimeError("BuildClassEntropyNetworkFromDescription : vocabulary size should be the same as the output layer size");
RuntimeError("BuildClassEntropyRNNFromDescription : vocabulary size should be the same as the output layer size");

input = builder.CreateSparseInputNode(L"features", m_layerSizes[0]);
m_net->FeatureNodes().push_back(input);
Expand Down Expand Up @@ -615,7 +615,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildLogBilinearNetworkFro
}

template <class ElemType>
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildNeuralProbNetworkFromDescription()
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildDNNLMNetworkFromDescription()
{
ComputationNetworkBuilder<ElemType> builder(*m_net);
if (m_net->GetTotalNumberOfNodes() < 1) // not built yet
Expand Down Expand Up @@ -952,7 +952,7 @@ shared_ptr<ComputationNode<ElemType>> /*ComputationNodePtr*/ SimpleNetworkBuilde
#ifdef COMING_SOON

template <class ElemType>
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildSeqTrnLSTMNetworkFromDescription()
ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildCRFLSTMNetworkFromDescription()
{
ComputationNetworkBuilder<ElemType> builder(*m_net);
if (m_net->GetTotalNumberOfNodes() < 1) // not built yet
Expand Down
66 changes: 31 additions & 35 deletions Source/CNTK/SimpleNetworkBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,22 @@ namespace Microsoft { namespace MSR { namespace CNTK {

#define MAX_DEPTH 20

enum RNNTYPE
// the standard network kinds that can be built with SimpleNetworkBuilder
enum StandardNetworkKind
{
SIMPLENET = 0, // no recurrent connections
SIMPLERNN = 1,
LSTM = 2,
DEEPRNN = 4,
CLASSLM = 8,
LBLM = 16,
NPLM = 32,
CLASSLSTM = 64,
NCELSTM = 128,
CLSTM = 256,
RCRF = 512
// basic
FFDNNKind = 0, // basic feed-forward
RNNKind = 1, // basic RNN
LSTMKind = 2, // basic LSTM
// class-based
ClassEntropyRNNKind = 8, // class-based RNN
ClassLSTMNetworkKind = 64, // class-based LSTM
// advanced
LogBilinearNetworkKind = 16, // log-bilinear model for language modeling
DNNLMNetworkKind = 32, // DNN-based LM
NCELSTMNetworkKind = 128, // NCE LSTM
ConditionalLSTMNetworkKind = 256, // conditional LM for text generation
CRFLSTMNetworkKind = 512, // sequential LSTM
};

enum class TrainingCriterion : int // TODO: camel-case these
Expand Down Expand Up @@ -165,27 +168,25 @@ class SimpleNetworkBuilder

stringargvector strType = str_rnnType;
if (std::find(strType.begin(), strType.end(), L"SIMPLENET") != strType.end())
m_rnnType = SIMPLENET;
m_standardNetworkKind = FFDNNKind;
else if (std::find(strType.begin(), strType.end(), L"SIMPLERNN") != strType.end())
m_rnnType = SIMPLERNN;
m_standardNetworkKind = RNNKind;
else if (std::find(strType.begin(), strType.end(), L"LSTM") != strType.end())
m_rnnType = LSTM;
else if (std::find(strType.begin(), strType.end(), L"DEEPRNN") != strType.end())
m_rnnType = DEEPRNN;
m_standardNetworkKind = LSTMKind;
else if (std::find(strType.begin(), strType.end(), L"CLASSLM") != strType.end())
m_rnnType = CLASSLM;
m_standardNetworkKind = ClassEntropyRNNKind;
else if (std::find(strType.begin(), strType.end(), L"LBLM") != strType.end())
m_rnnType = LBLM;
m_standardNetworkKind = LogBilinearNetworkKind;
else if (std::find(strType.begin(), strType.end(), L"NPLM") != strType.end())
m_rnnType = NPLM;
m_standardNetworkKind = DNNLMNetworkKind;
else if (std::find(strType.begin(), strType.end(), L"CLASSLSTM") != strType.end())
m_rnnType = CLASSLSTM;
m_standardNetworkKind = ClassLSTMNetworkKind;
else if (std::find(strType.begin(), strType.end(), L"NCELSTM") != strType.end())
m_rnnType = NCELSTM;
m_standardNetworkKind = NCELSTMNetworkKind;
else if (std::find(strType.begin(), strType.end(), L"CLSTM") != strType.end())
m_rnnType = CLSTM;
m_standardNetworkKind = ConditionalLSTMNetworkKind;
else if (std::find(strType.begin(), strType.end(), L"CRF") != strType.end())
m_rnnType = RCRF;
m_standardNetworkKind = CRFLSTMNetworkKind;
else
InvalidArgument("InitRecurrentConfig: unknown value for rnnType parameter '%ls'", strType[0].c_str());
}
Expand Down Expand Up @@ -241,21 +242,16 @@ class SimpleNetworkBuilder

ComputationNetworkPtr BuildNetworkFromDbnFile(const std::wstring& dbnModelFileName); // legacy support for fseide's Microsoft-internal tool "DBN.exe"

RNNTYPE RnnType()
{
return m_rnnType;
}

protected:

ComputationNetworkPtr BuildSimpleDNNFromDescription();
ComputationNetworkPtr BuildSimpleRNNFromDescription();
ComputationNetworkPtr BuildClassEntropyNetworkFromDescription();
ComputationNetworkPtr BuildFFDNNFromDescription();
ComputationNetworkPtr BuildRNNFromDescription();
ComputationNetworkPtr BuildClassEntropyRNNFromDescription();
ComputationNetworkPtr BuildLogBilinearNetworkFromDescription();
ComputationNetworkPtr BuildNeuralProbNetworkFromDescription();
ComputationNetworkPtr BuildDNNLMNetworkFromDescription();
ComputationNetworkPtr BuildLSTMNetworkFromDescription();
#ifdef COMING_SOON
ComputationNetworkPtr BuildSeqTrnLSTMNetworkFromDescription();
ComputationNetworkPtr BuildCRFLSTMNetworkFromDescription();
#endif
ComputationNetworkPtr BuildClassLSTMNetworkFromDescription();
ComputationNetworkPtr BuildConditionalLSTMNetworkFromDescription();
Expand Down Expand Up @@ -343,7 +339,7 @@ class SimpleNetworkBuilder
// recurrent network
intargvector m_recurrentLayers;
float m_defaultHiddenActivity;
RNNTYPE m_rnnType;
StandardNetworkKind m_standardNetworkKind;
int m_maOrder; // MA model order

bool m_constForgetGateValue;
Expand Down
2 changes: 1 addition & 1 deletion Source/Math/CPUMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5290,7 +5290,7 @@ void CPUMatrix<ElemType>::RCRFBackwardCompute(const CPUMatrix<ElemType>& alpha,
}
};

/// the kernel function for RCRF backward computation
/// the kernel function for RCRF backward computation
template <class ElemType>
void CPUMatrix<ElemType>::_rcrfBackwardCompute(size_t t, size_t k, const CPUMatrix<ElemType>& alpha,
CPUMatrix<ElemType>& beta,
Expand Down
8 changes: 4 additions & 4 deletions Source/Math/GPUMatrixCUDAKernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4575,7 +4575,7 @@ __global__ void _assignElementProductOfWithShift(
us[id] = a[id] * b[tmpidb];
}

/// minus 1 at a specific position
// minus 1 at a specific position
template <class ElemType>
__global__ void _minusOneAt(
ElemType* c,
Expand All @@ -4589,8 +4589,8 @@ __global__ void _minusOneAt(
c[id] = c[id] - 1.0;
}

/// the kernel function for RCRF backward computation
/// assume a column slice of input and output
// the kernel function for RCRF backward computation
// assume a column slice of input and output
template <class ElemType>
__global__ void _rcrfBackwardCompute(
const size_t iNumPos,
Expand Down Expand Up @@ -4662,7 +4662,7 @@ __global__ void _rcrfBackwardCompute(
// __syncthreads();
}

/// the kernel function for RCRF backward computation
/// the kernel function for CRFLSTMNetwork backward computation
/// assume a column slice of input and output
template <class ElemType>
__global__ void _rcrfBackwardCompute(
Expand Down

0 comments on commit 198d64d

Please sign in to comment.