Skip to content

Commit

Permalink
more strict control of compiled/non-compiled state of networks in MEL…
Browse files Browse the repository at this point in the history
…. All editing operations invalidate compilation, and Saving the model will validate the model if changed (if not it was validated during Load)
  • Loading branch information
frankseide committed Jan 14, 2016
1 parent c89b632 commit 40cbeac
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 11 deletions.
5 changes: 2 additions & 3 deletions Source/CNTK/ModelEditLanguage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa

// validate the network before we save it out
ProcessNDLScript(m_netNdlDefault, ndlPassAll, true);

cn->Save(fileName);
cn->SaveEdited(fileName);
}
else if (EqualInsensitive(name, "SaveModel"))
{
Expand All @@ -219,7 +218,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa

// validate and finish the second pass through NDL if any in-line NDL was defined
ProcessNDLScript(netNdl, ndlPassAll, true);
netNdl->cn->Save(fileName);
netNdl->cn->SaveEdited(fileName);
}
else if (EqualInsensitive(name, "SetDefaultModel"))
{
Expand Down
10 changes: 10 additions & 0 deletions Source/ComputationNetworkLib/ComputationNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// break cycles
// BUGBUG: This only works if nodes are not shared across networks.
// Once we allow that (BrainScript editing), we need proper cycle detectors. Luckily, we know our cycles, so it won't be too hard.
// Or just use weak ptrs.
for (auto & iter : m_nameToNodeMap)
iter.second->DetachInputs();

Expand All @@ -74,8 +75,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// serialization
// -----------------------------------------------------------------------

// after after editing--network is possibly not validated/compiled
void ComputationNetwork::SaveEdited(const wstring& fileName, const FileOptions fileFormat)
{
if (!IsCompiled())
CompileNetwork();
Save(fileName, fileFormat);
}

void ComputationNetwork::Save(const wstring& fileName, const FileOptions fileFormat) const
{
VerifyIsCompiled("Save");
// In case of parallel training only the main node should we saving the model to prevent
// the parallel training nodes from colliding to write the same file
// TODO: This does not belong here.
Expand Down
2 changes: 2 additions & 0 deletions Source/ComputationNetworkLib/ComputationNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
// -----------------------------------------------------------------------

void Save(const std::wstring& fileName, const FileOptions fileFormat = FileOptions::fileOptionsBinary) const;
void SaveEdited(const std::wstring& fileName, const FileOptions fileFormat = FileOptions::fileOptionsBinary);
private:
void SaveToFileImpl(const std::wstring& fileName, const FileOptions fileFormat) const;
public:
Expand Down Expand Up @@ -171,6 +172,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
private:
void DetermineSetOfAllRoots();
void CollectInputAndLearnableParameters(const ComputationNodeBasePtr& rootNode);
bool IsCompiled() const { return m_isCompiled; }
void VerifyIsCompiled(const char * where) const;
//bool BuiltAndValidatedSubNetwork(const ComputationNodeBasePtr & rootNode);
public:
Expand Down
23 changes: 16 additions & 7 deletions Source/ComputationNetworkLib/ComputationNetworkEditing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
std::wstring toName,
const CopyNodeFlags flags)
{
InvalidateCompiledNetwork();

if (toName == L"")
toName = fromName;

Expand All @@ -50,11 +52,10 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
else
{
//node already exists

// node already exists
pToNode = GetNodeFromName(toName);

//same node. no copy needed
// same node. no copy needed
if (pFromNode == pToNode)
LogicError("CopyNode: You are copying the node to the same network with same node name.");
else
Expand All @@ -69,6 +70,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
const std::wstring fromName, std::wstring toNamePrefix,
const CopyNodeFlags flags)
{
InvalidateCompiledNetwork();

if (!(flags & CopyNodeFlags::copyNodeValue))
LogicError("CopySubTree: you cannot copy a tree without copying the node values.");

Expand Down Expand Up @@ -103,7 +106,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// nodeNameNew - new node name
void ComputationNetwork::RenameNode(const std::wstring& nodeNameOrig, const std::wstring& nodeNameNew)
{
// so that renamed node will not be referenced
InvalidateCompiledNetwork();

ComputationNodeBasePtr nodeToRename = GetNodeFromName(nodeNameOrig);
Expand All @@ -128,7 +130,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {

void ComputationNetwork::DeleteNode(const std::wstring & nodeName)
{
// so that deleted node will not be referenced
InvalidateCompiledNetwork();

ComputationNodeBasePtr nodeToDelete = GetNodeFromName(nodeName);
Expand Down Expand Up @@ -172,6 +173,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// need to update all the mappings as well childrens
void ComputationNetwork::ChangeNode(wstring nodeName, ComputationNodeBasePtr newNode)
{
InvalidateCompiledNetwork();

ComputationNodeBasePtr oldNode = GetNodeFromName(nodeName);
if (oldNode->OperationName() != newNode->OperationName())
InvalidArgument("newNode must have the same type as the old node.");
Expand Down Expand Up @@ -204,6 +207,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// need to update those nodes who use oldNode as their child
void ComputationNetwork::ReplaceLeafNode(wstring oldNodeName, ComputationNodeBasePtr newNode)
{
InvalidateCompiledNetwork();

ComputationNodeBasePtr oldNode = GetNodeFromName(oldNodeName);

// change the input of those nodes whose child is oldNode
Expand All @@ -223,6 +228,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {

void ComputationNetwork::ReplaceFinalCriterionNode(wstring oldNodeName, ComputationNodeBasePtr newNode)
{
InvalidateCompiledNetwork();

// Checks if the node is a criterion node.
int index = -1;
for (int i = 0; i < m_finalCriteria.size(); ++i)
Expand Down Expand Up @@ -251,6 +258,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {

void ComputationNetwork::AddFeatureNode(ComputationNodeBasePtr featureNode)
{
InvalidateCompiledNetwork();

wstring nodeName = featureNode->NodeName();
if (NodeNameExists(nodeName))
RuntimeError("AddFeatureNode: feature node already exists.");
Expand All @@ -261,12 +270,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// We only remove the node, not delete it.
void ComputationNetwork::RemoveFeatureNode(ComputationNodeBasePtr featureNode)
{
InvalidateCompiledNetwork();

wstring nodeName = featureNode->NodeName();
if (!NodeNameExists(nodeName))
RuntimeError("RemoveFeatureNode: feature node does not exist.");

InvalidateCompiledNetwork();

// Removes links.
for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); ++nodeIter)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// verify that network has undergone CompileNetwork()
void ComputationNetwork::VerifyIsCompiled(const char * where) const
{
if (!m_isCompiled)
if (!IsCompiled())
LogicError("%s: A compiled network was expected.", where);
}

Expand Down

0 comments on commit 40cbeac

Please sign in to comment.