Skip to content

Commit

Permalink
Fixed redundant writing of the .0 model file when no PreCompute is pe…
Browse files Browse the repository at this point in the history
…rformed and the model is loaded from a checkpoint
  • Loading branch information
amitaga committed Feb 4, 2016
1 parent a930d1c commit d3e636c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
26 changes: 16 additions & 10 deletions Source/SGDLib/SGD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,15 @@ void SGD<ElemType>::Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createN
}

wstring modelFileName = GetModelNameForEpoch(int(startEpoch) - 1);
bool loadNetworkFromCheckpoint = false;
if (startEpoch >= 0)
{
loadNetworkFromCheckpoint = true;
fprintf(stderr, "Starting from checkpoint. Load Network From File %ls.\n", modelFileName.c_str());
}

// create or load from checkpoint
shared_ptr<ComputationNetwork> net = startEpoch < 0 ? createNetworkFn(deviceId) : ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelFileName);
shared_ptr<ComputationNetwork> net = !loadNetworkFromCheckpoint ? createNetworkFn(deviceId) : ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelFileName);

// log the device we are computing on
if (net->GetDeviceId() < 0)
Expand All @@ -68,7 +72,7 @@ void SGD<ElemType>::Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createN
startEpoch = max(startEpoch, 0);
m_needAdaptRegularization = false;

TrainOrAdaptModel(startEpoch, net, net, nullptr, trainSetDataReader, validationSetDataReader);
TrainOrAdaptModel(startEpoch, net, loadNetworkFromCheckpoint, net, nullptr, trainSetDataReader, validationSetDataReader);
}

// -----------------------------------------------------------------------
Expand All @@ -89,11 +93,13 @@ void SGD<ElemType>::Adapt(wstring origModelFileName, wstring refNodeName,
}

ComputationNetworkPtr net;
bool networkLoadedFromCheckpoint = false;
if (startEpoch >= 0)
{
wstring modelFileName = GetModelNameForEpoch(int(startEpoch) - 1);
fprintf(stderr, "Starting from checkpoint. Load Network From File %ls.\n", modelFileName.c_str());
net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelFileName);
networkLoadedFromCheckpoint = true;
}
else
{
Expand All @@ -120,7 +126,7 @@ void SGD<ElemType>::Adapt(wstring origModelFileName, wstring refNodeName,
refNode = refNet->GetNodeFromName(refNodeName);
}

TrainOrAdaptModel(startEpoch, net, refNet, refNode, trainSetDataReader, validationSetDataReader);
TrainOrAdaptModel(startEpoch, net, networkLoadedFromCheckpoint, refNet, refNode, trainSetDataReader, validationSetDataReader);
}

// -----------------------------------------------------------------------
Expand All @@ -131,6 +137,7 @@ static double MomentumPerMB(double momentumPerSample, size_t minibatchSize);

template <class ElemType>
void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
bool networkLoadedFromCheckpoint,
ComputationNetworkPtr refNet,
ComputationNodeBasePtr refNode,
IDataReader<ElemType>* trainSetDataReader,
Expand Down Expand Up @@ -259,7 +266,9 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
InitDistGradAgg(evaluationNodes.size(), m_traceLevel);
}
// precompute mean and invStdDev nodes and save initial model
if (PreCompute(net, trainSetDataReader, featureNodes, labelNodes, inputMatrices) || startEpoch == 0)
// When no precompute, only save if we did not load the model from a
// checkpoint but instead built it from a network description
if (PreCompute(net, trainSetDataReader, featureNodes, labelNodes, inputMatrices) || !networkLoadedFromCheckpoint)
{
// Synchronize all ranks before writing the model to ensure that
// everyone is done loading the model
Expand All @@ -268,8 +277,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
g_mpi->WaitAll();
}

if ((g_mpi == nullptr) || g_mpi->IsMainNode())
net->Save(GetModelNameForEpoch(int(startEpoch) - 1));
net->Save(GetModelNameForEpoch(int(startEpoch) - 1));
}

bool learnRateInitialized = false;
Expand Down Expand Up @@ -363,8 +371,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
i + 1, learnRatePerSample, m_minLearnRate);
if (m_autoLearnRateSearchType != LearningRateSearchAlgorithm::None)
{
if ((g_mpi == nullptr) || g_mpi->IsMainNode())
net->Save(m_modelPath);
net->Save(m_modelPath);
}
break;
}
Expand Down Expand Up @@ -567,8 +574,7 @@ void SGD<ElemType>::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
}
else
{
if ((g_mpi == nullptr) || g_mpi->IsMainNode())
net->Save(GetModelNameForEpoch(i, true));
net->Save(GetModelNameForEpoch(i, true));

fprintf(stderr, "Finished training and saved final model\n\n");
break;
Expand Down
1 change: 1 addition & 0 deletions Source/SGDLib/SGD.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ class SGD : public SGDParams
std::vector<ComputationNodeBasePtr>& GetEvalCriterionNodes(ComputationNetworkPtr net);

void TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net,
bool networkLoadedFromCheckpoint,
ComputationNetworkPtr refNet,
ComputationNodeBasePtr refNode,
IDataReader<ElemType>* trainSetDataReader,
Expand Down

0 comments on commit d3e636c

Please sign in to comment.