Skip to content

Commit

Permalink
changed most uses of ComputationNetwork * and & to ComputationNetwork…
Browse files Browse the repository at this point in the history
…Ptr, to eliminate ownership bugs and allow integration with BS. Also allowed for some minor code simplifications;

made IComputationNetBuilder::LoadNetworkFromFile() 'protected' since it is no longer used. Will be deleted soon
  • Loading branch information
frankseide committed Nov 21, 2015
1 parent 38ca2aa commit 04e0621
Show file tree
Hide file tree
Showing 26 changed files with 1,685 additions and 1,712 deletions.
53 changes: 18 additions & 35 deletions MachineLearning/CNTK/CNTK.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@ void DoEvalBase(const ConfigParameters& config, IDataReader<ElemType>& reader)
evalNodeNamesVector.push_back(evalNodeNames[i]);
}

ComputationNetwork net(deviceId);
net.LoadFromFile<ElemType>(modelPath);
net.ResetEvalTimeStamp();
auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelPath);

SimpleEvaluator<ElemType> eval(net, numMBsToShowResult, traceLevel);
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
Expand Down Expand Up @@ -180,9 +178,7 @@ void DoEvalUnroll(const ConfigParameters& config)
intargvector mbSize = minibatchSize;
wstring path2EvalResults = config(L"path2EvalResults", L"");

ComputationNetwork net(deviceId);
net.LoadFromFile<ElemType>(modelPath);
net.ResetEvalTimeStamp();
auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelPath);

MultiNetworksEvaluator<ElemType> eval(net);
double evalEntropy;
Expand Down Expand Up @@ -244,9 +240,7 @@ void DoCrossValidate(const ConfigParameters& config)
}

cvModels.push_back(cvModelPath);
ComputationNetwork net(deviceId);
net.LoadFromFile<ElemType>(cvModelPath);
net.ResetEvalTimeStamp();
auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, cvModelPath);

SimpleEvaluator<ElemType> eval(net, numMBsToShowResult, traceLevel);

Expand Down Expand Up @@ -320,9 +314,7 @@ void DoWriteOutput(const ConfigParameters& config)
outputNodeNamesVector.push_back(outputNodeNames[i]);
}

ComputationNetwork net(deviceId);
net.LoadFromFile<ElemType>(modelPath);
net.ResetEvalTimeStamp();
auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelPath);

SimpleOutputWriter<ElemType> writer(net, 1);

Expand Down Expand Up @@ -803,7 +795,7 @@ class BrainScriptNetworkBuilder : public IComputationNetBuilder<ElemType>
BrainScriptNetworkBuilder(const ConfigParameters & config) { NOT_IMPLEMENTED; }

// build a ComputationNetwork from description language
virtual /*IComputationNetBuilder::*/ComputationNetwork* BuildNetworkFromDescription(ComputationNetwork* = nullptr) override
virtual /*IComputationNetBuilder::*/ComputationNetworkPtr BuildNetworkFromDescription(ComputationNetwork* = nullptr) override
{
vector<ScriptableObjects::ConfigValuePtr> args; // this lambda has no arguments
ScriptableObjects::ConfigLambda::NamedParams namedArgs;
Expand All @@ -813,7 +805,7 @@ class BrainScriptNetworkBuilder : public IComputationNetBuilder<ElemType>
fprintf(stderr, "BrainScriptNetworkBuilder using CPU\n");
else
fprintf(stderr, "BrainScriptNetworkBuilder using GPU %d\n", (int)m_net->GetDeviceId());
return m_net.get();
return m_net;
}

// load an existing file--this is the same code as for NDLNetworkBuilder.h (OK to copy it here because this is temporary code anyway)
Expand Down Expand Up @@ -876,7 +868,7 @@ void DoTrain(const ConfigRecordType & config)
else if (config.Exists(L"SimpleNetworkBuilder"))
{
const ConfigRecordType & simpleNetworkBuilderConfig(config(L"SimpleNetworkBuilder", ConfigRecordType::Record()));
shared_ptr<IComputationNetBuilder<ElemType>> netBuilder = make_shared<SimpleNetworkBuilder<ElemType>>(simpleNetworkBuilderConfig);
auto netBuilder = make_shared<SimpleNetworkBuilder<ElemType>>(simpleNetworkBuilderConfig);
createNetworkFn = [netBuilder](DEVICEID_TYPE deviceId)
{
return shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription());
Expand All @@ -886,7 +878,7 @@ void DoTrain(const ConfigRecordType & config)
else if (config.Exists(L"NDLNetworkBuilder"))
{
const ConfigRecordType & ndlNetworkBuilderConfig(config(L"NDLNetworkBuilder", ConfigRecordType::Record()));
shared_ptr<IComputationNetBuilder<ElemType>> netBuilder = make_shared<NDLBuilder<ElemType>>(ndlNetworkBuilderConfig);
shared_ptr<NDLBuilder<ElemType>> netBuilder = make_shared<NDLBuilder<ElemType>>(ndlNetworkBuilderConfig);
createNetworkFn = [netBuilder](DEVICEID_TYPE deviceId)
{
return shared_ptr<ComputationNetwork>(netBuilder->BuildNetworkFromDescription());
Expand Down Expand Up @@ -1063,7 +1055,7 @@ void DoEncoderDecoder(const ConfigParameters& config)
validationDataReader.push_back(cvEncoderDataReader);
validationDataReader.push_back(cvDecoderDataReader);

sgd.EncoderDecoder(netBuilders, trainDataReader, validationDataReader, makeMode);
sgd.EncoderDecoder(netBuilders, (int)config(L"deviceId"), trainDataReader, validationDataReader, makeMode);

delete encoderDataReader;
delete decoderDataReader;
Expand Down Expand Up @@ -1149,7 +1141,7 @@ void DoBidirectionEncoderDecoder(const ConfigParameters& config)
validationDataReader.push_back(cvDecoderDataReader);
validationDataReader.push_back(cvBackwardDecoderDataReader);

sgd.EncoderDecoder(netBuilders, trainDataReader, validationDataReader, makeMode);
sgd.EncoderDecoder(netBuilders, (int)config(L"deviceId"), trainDataReader, validationDataReader, makeMode);

delete encoderDataReader;
delete decoderDataReader;
Expand Down Expand Up @@ -1198,17 +1190,13 @@ void DoEvalEncodingBeamSearchDecoding(const ConfigParameters& config)
int traceLevel = config(L"traceLevel", "0");
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");

vector<ComputationNetwork*> nets;
ComputationNetwork encoderNet(deviceId);
encoderNet.LoadFromFile<ElemType>(encoderModelPath, FileOptions::fileOptionsBinary, true);
encoderNet.ResetEvalTimeStamp();
vector<ComputationNetworkPtr> nets;
auto encoderNet = ComputationNetwork::CreateFromFile<ElemType>(deviceId, encoderModelPath, FileOptions::fileOptionsBinary, true);

ComputationNetwork decoderNet(deviceId);
decoderNet.LoadFromFile<ElemType>(decoderModelPath, FileOptions::fileOptionsBinary, false, &encoderNet);
decoderNet.ResetEvalTimeStamp();
auto decoderNet = ComputationNetwork::CreateFromFile<ElemType>(deviceId, decoderModelPath, FileOptions::fileOptionsBinary, false, encoderNet.get());

nets.push_back(&encoderNet);
nets.push_back(&decoderNet);
nets.push_back(encoderNet);
nets.push_back(decoderNet);
ConfigArray evalNodeNames = config(L"evalNodeNames");
vector<wstring> evalNodeNamesVector;
for (int i = 0; i < evalNodeNames.size(); ++i)
Expand Down Expand Up @@ -1273,9 +1261,7 @@ void DoEvalBeamSearch(const ConfigParameters& config, IDataReader<ElemType>& rea
int traceLevel = config(L"traceLevel", "0");
size_t numMBsToShowResult = config(L"numMBsToShowResult", "100");

ComputationNetwork net(deviceId);
net.LoadFromFile<ElemType>(modelPath);
net.ResetEvalTimeStamp();
auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelPath);

ConfigArray evalNodeNames = config(L"evalNodeNames");
vector<wstring> evalNodeNamesVector;
Expand Down Expand Up @@ -1365,15 +1351,12 @@ void DoEdit(const ConfigParameters& config)
template <typename ElemType>
void DoConvertFromDbn(const ConfigParameters& config)
{
//config.Insert("deviceId","-1"); //force using CPU

wstring modelPath = config(L"modelPath");
wstring dbnModelPath = config(L"dbnModelPath");

IComputationNetBuilder<ElemType>* netBuilder = (IComputationNetBuilder<ElemType>*)new SimpleNetworkBuilder<ElemType>(config);
ComputationNetwork* net = netBuilder->LoadNetworkFromFile(dbnModelPath);
auto netBuilder = make_shared<SimpleNetworkBuilder<ElemType>>(config);
ComputationNetworkPtr net = netBuilder->BuildNetworkFromDbnFile(dbnModelPath);
net->SaveToFile(modelPath);
delete (netBuilder);
}

// do topological plot of computation network
Expand Down
6 changes: 4 additions & 2 deletions MachineLearning/CNTK/ExperimentalNetworkBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
//BinaryStandardNode(TransposeTimesNode)
;

#if 0 // no longer needed
namespace Microsoft { namespace MSR { namespace CNTK {

using namespace Microsoft::MSR;
Expand All @@ -137,7 +138,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {

// build a ComputationNetwork from BrainScript source code
template<class ElemType>
/*virtual*/ /*IComputationNetBuilder::*/ComputationNetwork* ExperimentalNetworkBuilder<ElemType>::BuildNetworkFromDescription(ComputationNetwork*)
/*virtual*/ /*IComputationNetBuilder::*/ComputationNetworkPtr ExperimentalNetworkBuilder<ElemType>::BuildNetworkFromDescription(ComputationNetwork*)
{
if (!m_net || m_net->GetTotalNumberOfNodes() < 1) //not built yet
{
Expand All @@ -160,10 +161,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// TODO: old CNTK code seems to be able to load the network in-place--is that important; is it OK to just replace the pointer?
}
m_net->ResetEvalTimeStamp();
return m_net.get();
return m_net;
}

template class ExperimentalNetworkBuilder<float>;
template class ExperimentalNetworkBuilder<double>;

}}}
#endif
7 changes: 5 additions & 2 deletions MachineLearning/CNTK/ExperimentalNetworkBuilder.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#if 0 // no longer needed
// ExperimentalNetworkBuilder.h -- interface to new version of NDL (and config) parser --fseide

#pragma once
Expand Down Expand Up @@ -29,12 +30,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {

// build a ComputationNetwork from description language
// TODO: change return type of these interfaces to shared_ptrs
virtual /*IComputationNetBuilder::*/ComputationNetwork* BuildNetworkFromDescription(ComputationNetwork* = nullptr);
virtual /*IComputationNetBuilder::*/ComputationNetworkPtr BuildNetworkFromDescription(ComputationNetwork* = nullptr) override;
// TODO: that function argument is related to PairNetworkNode, which will go away (we don't support it here)

// load an existing file--this is the same code as for NDLNetworkBuilder.h (OK to copy it here because this is temporary code anyway)
virtual /*IComputationNetBuilder::*/ComputationNetwork* LoadNetworkFromFile(const wstring& modelFileName, bool forceLoad = true,
bool bAllowNoCriterionNode = false, ComputationNetwork* anotherNetwork = nullptr)
bool bAllowNoCriterionNode = false,
ComputationNetwork* anotherNetwork = nullptr) override
{
if (!m_net || m_net->GetTotalNumberOfNodes() == 0 || forceLoad) //not built or force load
{
Expand All @@ -48,3 +50,4 @@ namespace Microsoft { namespace MSR { namespace CNTK {
};

}}}
#endif
2 changes: 1 addition & 1 deletion MachineLearning/CNTK/IExecutionEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
class IExecutionEngine
{
public:
virtual ComputationNetwork & GetComputationNetwork() = 0;
virtual ComputationNetworkPtr GetComputationNetwork() = 0;

virtual NDLNodeEvaluator<ElemType> & GetNodeEvaluator() = 0;

Expand Down
14 changes: 7 additions & 7 deletions MachineLearning/CNTK/ModelEditLanguage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa
if (params.size() > numFixedParams + numOptionalParams || params.size() < numFixedParams)
RuntimeError("Invalid number of parameters. Valid parameters: CreateModel(). newly created model always becomes the new default.");

ComputationNetwork* cn = new ComputationNetwork(CPUDEVICE);
auto cn = make_shared<ComputationNetwork>(CPUDEVICE);
OverrideModelNameAndSetDefaultModel(cn);
}
if (EqualInsensitive(name, "CreateModelWithName")) //create a blank model
Expand All @@ -113,7 +113,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa
if (params.size() > numFixedParams + numOptionalParams || params.size() < numFixedParams)
RuntimeError("Invalid number of parameters. Valid parameters: CreateModelWithName(modelName). newly created model always becomes the new default.");

ComputationNetwork* cn = new ComputationNetwork(CPUDEVICE);
auto cn = make_shared<ComputationNetwork>(CPUDEVICE);
OverrideModelNameAndSetDefaultModel(cn, params[0]);
}
else if (EqualInsensitive(name, "LoadModel"))
Expand All @@ -124,7 +124,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa

std::wstring modelFormat = GetOptionalModelFormat(params, numFixedParams);

ComputationNetwork* cn = new ComputationNetwork(CPUDEVICE);
auto cn = make_shared<ComputationNetwork>(CPUDEVICE);
cn->LoadFromFile<ElemType>(params[0]);
OverrideModelNameAndSetDefaultModel(cn);
}
Expand All @@ -136,7 +136,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa

std::wstring modelFormat = GetOptionalModelFormat(params, numFixedParams);

ComputationNetwork* cn = new ComputationNetwork(CPUDEVICE);
auto cn = make_shared<ComputationNetwork>(CPUDEVICE);
cn->LoadFromFile<ElemType>(params[1]);
OverrideModelNameAndSetDefaultModel(cn, params[0]);
}
Expand All @@ -148,7 +148,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa

string modelName = params[0];
wstring ndlSnippetFileName = params[1];
ComputationNetwork* cn = new ComputationNetwork(CPUDEVICE);
auto cn = make_shared<ComputationNetwork>(CPUDEVICE);
NDLScript<ElemType> script;
ConfigParameters ndlScript (script.ReadConfigFile(ndlSnippetFileName));

Expand Down Expand Up @@ -181,7 +181,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa

std::wstring fileName = params[0];

ComputationNetwork* cn = m_netNdlDefault->cn;
auto cn = m_netNdlDefault->cn;
if (cn == NULL)
RuntimeError("SaveDefaultModel can only be called after a default name exists (i.e., at least one model is loaded.)");

Expand Down Expand Up @@ -440,7 +440,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa
// this probabably won't do anything, but make sure all NDL has been created
ProcessNDLScript(netNdl, ndlPassInitial, false);

ComputationNetwork* cn = netNdl->cn;
auto cn = netNdl->cn;
for (auto & node : nodes)
{
switch(prop)
Expand Down
8 changes: 4 additions & 4 deletions MachineLearning/CNTK/ModelEditLanguage.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class MELScript: public ConfigParser
search = symbol.substr(firstStart);
}

ComputationNetwork* cn = netNdl->cn;
ComputationNetworkPtr cn = netNdl->cn;
wstring name = msra::strfun::utf16(search);
vector<ComputationNodeBasePtr> nodes = cn->GetNodesFromName(name);
// didn't find the name in the current symbols, try NDL
Expand Down Expand Up @@ -378,7 +378,7 @@ class MELScript: public ConfigParser
}
}

void OverrideModelNameAndSetDefaultModel(ComputationNetwork* cn, string modelName = "default")
void OverrideModelNameAndSetDefaultModel(ComputationNetworkPtr cn, string modelName = "default")
{
auto found = m_mapNameToNetNdl.find(modelName);
if (found != m_mapNameToNetNdl.end() && found->second.cn != cn)
Expand Down Expand Up @@ -583,7 +583,7 @@ class MELScript: public ConfigParser
// EvaluateNDLSnippet - evaluate the passed snippet of NDL into a computational network
// script - [in] text of the NDL snippet
// network - [in/out] computation network to insert NDL into
void EvaluateNDLSnippet(const ConfigValue& script, ComputationNetwork* network)
void EvaluateNDLSnippet(const ConfigValue& script, ComputationNetworkPtr network)
{
NDLUtil<ElemType> ndlUtil(network);
ndlUtil.ProcessNDLConfig(script);
Expand Down Expand Up @@ -646,7 +646,7 @@ class MELScript: public ConfigParser
// model1=[...] - Embedded NDL script
if (0 == foundBrace)
{
ComputationNetwork* cn = new ComputationNetwork();
ComputationNetworkPtr cn = make_shared<ComputationNetwork>();
EvaluateNDLSnippet(rightValue, cn);
OverrideModelNameAndSetDefaultModel(cn, key);
}
Expand Down
16 changes: 8 additions & 8 deletions MachineLearning/CNTK/NDLNetworkBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
const ConfigParameters* m_baseConfig; // NOTE: the lifetime of the parent MUST exist from the call to Init to the BuildNetworkFromDescription() call for stringize

public:
NDLBuilder() : m_net(nullptr)
NDLBuilder()
{
m_executionEngine = NULL;
m_baseConfig = NULL;
} // empty constructor, call Init immediately hereafter

NDLBuilder(const ConfigParameters& config) : m_net(nullptr)
NDLBuilder(const ConfigParameters& config)
{
m_baseConfig = config.GetParent();
Init(config);
Expand All @@ -57,7 +57,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
m_dumpFileName = dumpFileName;
m_initialConfig = configParams;
m_deviceId = deviceId;
m_net = &(executionEngine->GetComputationNetwork());
m_net = executionEngine->GetComputationNetwork();
if (m_deviceId == AUTOPLACEMATRIX)
m_deviceId = Matrix<ElemType>::GetBestGPUDeviceId();
m_deviceId = EnforceOneGPUOnly(m_deviceId); // see EnforceOneGPUOnly() for comment on what this is
Expand Down Expand Up @@ -158,16 +158,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}

virtual ComputationNetwork* LoadNetworkFromFile(const wstring& modelFileName, bool forceLoad = true,
bool bAllowNoCriterionNode = false, ComputationNetwork* anotherNetwork = nullptr)
bool bAllowNoCriterionNode = false, ComputationNetwork* anotherNetwork = nullptr) override
{
if (m_net->GetTotalNumberOfNodes() == 0 || forceLoad) //not built or force load
m_net->LoadFromFile<ElemType>(modelFileName, FileOptions::fileOptionsBinary, bAllowNoCriterionNode, anotherNetwork);

m_net->ResetEvalTimeStamp();
return m_net;
return m_net.get();
}

ComputationNetwork* LoadNetworkFromConfig(const wstring& configFilePaths, bool forceLoad = true)
ComputationNetworkPtr LoadNetworkFromConfig(const wstring& configFilePaths, bool forceLoad = true)
{
if (m_net->GetTotalNumberOfNodes() == 0 || forceLoad) //not built or force load
LoadFromConfig(configFilePaths);
Expand Down Expand Up @@ -214,7 +214,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
ndlUtil.ProcessNDLConfig(config, true);
}

virtual ComputationNetwork* BuildNetworkFromDescription(ComputationNetwork* = nullptr)
virtual ComputationNetworkPtr BuildNetworkFromDescription(ComputationNetwork* = nullptr) override
{
if (m_net->GetTotalNumberOfNodes() < 1) //not built yet
{
Expand All @@ -226,7 +226,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}

private:
ComputationNetwork* m_net;
ComputationNetworkPtr m_net;
IExecutionEngine<ElemType>* m_executionEngine;
std::wstring m_networkConfig;
std::wstring m_dumpFileName;
Expand Down
Loading

0 comments on commit 04e0621

Please sign in to comment.