Skip to content

Commit

Permalink
SGDOptimizer BS constructor implemented;
Browse files Browse the repository at this point in the history
abstracted out from DoTrain() creation of objects, with two versions one for old CNTK config, and one for BS
  • Loading branch information
frankseide committed Nov 22, 2015
1 parent fefa2ac commit 4140505
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 14 deletions.
1 change: 0 additions & 1 deletion Common/Include/ScriptableObjects.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,6 @@ namespace Microsoft { namespace MSR { namespace ScriptableObjects {
bool Exists(const wstring & id) const { return Find(id) != nullptr; }
static const IConfigRecord & Record();
template<class V> static const std::vector<typename V::value_type> & Array(const V & vec);

};
typedef shared_ptr<struct IConfigRecord> IConfigRecordPtr;

Expand Down
38 changes: 27 additions & 11 deletions MachineLearning/CNTK/CNTK.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,21 @@ function<ComputationNetworkPtr(DEVICEID_TYPE)> GetCreateNetworkFn(const Scriptab
}
function<ComputationNetworkPtr(DEVICEID_TYPE)> GetCreateNetworkFn(const ConfigParameters &) { NOT_IMPLEMENTED; } // old CNTK config does not support lambdas

// function to create an object of a certain type, using both old CNTK config and BrainScript
template<class C>
shared_ptr<C> CreateObject(const ScriptableObjects::IConfigRecord & config, const wchar_t * id)
{
// TODO: CNTK config added "traceLevel = 0" to 'config'. In BS, we cannot do that (IConfigRecord is immutable). Solution: Just say "traceLevel = 0" in the BS macros for readers.
return config[id].AsPtr<C>(); // BS instantiates this object through this call
}
template<class C>
shared_ptr<C> CreateObject(const ConfigParameters & config, const wchar_t * id)
{
ConfigParameters readerConfig(config(id));
readerConfig.Insert("traceLevel", config(L"traceLevel", "0")); // TODO: fix this by adding it to all config blocks. Easy to fix in BS as 'config with [ traceLevel = 0 ]'.
return make_shared<C>(readerConfig); // old CNTK config specifies a dictionary which then must be explicitly instantiated
}

template <class ConfigRecordType, typename ElemType>
void DoTrain(const ConfigRecordType & config)
{
Expand Down Expand Up @@ -913,23 +928,24 @@ void DoTrain(const ConfigRecordType & config)
RuntimeError("No network builder found in the config file. NDLNetworkBuilder or SimpleNetworkBuilde must be specified");
}

// BUGBUG: inconsistency with BrainScript: old config passes a config dict, whereas BrainScript creates the object right away
const ConfigRecordType & readerConfig(config(L"reader"));
//readerConfig.Insert("traceLevel", config(L"traceLevel", "0")); // TODO: fix this by adding it to all config blocks. Easy to fix in BS as 'config with [ traceLevel = 0 ]'.
auto dataReader = make_shared<DataReader<ElemType>>(readerConfig);
auto dataReader = CreateObject<DataReader<ElemType>>(config, L"reader");

shared_ptr<DataReader<ElemType>> cvDataReader;
if (config.Exists(L"cvReader"))
cvDataReader = CreateObject<DataReader<ElemType>>(config, L"cvReader");

shared_ptr<SGD<ElemType>> optimizer;
if (config.Exists(L"optimizer"))
{
const ConfigRecordType & cvReaderConfig(config(L"cvReader"));
//cvReaderConfig.Insert("traceLevel", config(L"traceLevel", "0"));
cvDataReader = unique_ptr<DataReader<ElemType> >{ new DataReader<ElemType>(cvReaderConfig) };
optimizer = CreateObject<SGD<ElemType>>(config, L"optimizer");
}
else // legacy CNTK config syntax: needs a record called 'SGD'
{
const ConfigRecordType & configSGD(config(L"SGD"));
optimizer = make_shared<SGD<ElemType>>(SGDParams(configSGD, sizeof(ElemType)));
}

const ConfigRecordType & configSGD(config(L"SGD"));
SGD<ElemType> sgd(SGDParams(configSGD, sizeof(ElemType)));

sgd.Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode);
optimizer->Train(createNetworkFn, deviceId, dataReader.get(), cvDataReader.get(), makeMode);
}

namespace Microsoft { namespace MSR { namespace ScriptableObjects {
Expand Down
4 changes: 2 additions & 2 deletions MachineLearning/CNTKSGDLib/SGD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2621,7 +2621,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
template class SGD<float>;
template class SGD<double>;

// register ComputationNode with the ScriptableObject system
ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<SGDParams> registerComputationNode(L"SGDParams");
// register SGD<> with the ScriptableObject system
ScriptableObjects::ConfigurableRuntimeTypeRegister::AddFloatDouble<SGD<float>,SGD<double>> registerSGDOptimizer(L"SGDOptimizer");

}}}
4 changes: 4 additions & 0 deletions MachineLearning/CNTKSGDLib/SGD.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ class SGD : public SGDParams
{ }
// note: This must be in the header, as we cannot properly specialize this constructor in the CPP to make sure all versions are generated.

SGD(const ScriptableObjects::IConfigRecordPtr configp) :
SGD(*configp)
{ }

SGD(SGDParams&& sgdParams);

void Train(function<ComputationNetworkPtr(DEVICEID_TYPE)> createNetworkFn, DEVICEID_TYPE deviceId,
Expand Down

0 comments on commit 4140505

Please sign in to comment.