Skip to content

Commit

Permalink
replaced StreamMinibatchInputs::operator[] by named searchable access…
Browse files Browse the repository at this point in the history
…or functions
  • Loading branch information
frankseide committed Feb 28, 2016
1 parent 24c88c0 commit f3a720e
Show file tree
Hide file tree
Showing 15 changed files with 121 additions and 96 deletions.
6 changes: 3 additions & 3 deletions Source/ActionsLib/OtherActions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void DoCreateLabelMap(const ConfigParameters& config)
Matrix<ElemType> featuresMatrix(CPUDEVICE);
Matrix<ElemType> labelsMatrix(CPUDEVICE);
StreamMinibatchInputs<ElemType> matrices;
matrices[featureNames[0]] = &featuresMatrix;
matrices.AddInputMatrix(featureNames[0], &featuresMatrix);
if (labelNames.size() == 0)
RuntimeError("CreateLabelMap: no labels found to process");

Expand All @@ -66,7 +66,7 @@ void DoCreateLabelMap(const ConfigParameters& config)
for (const std::wstring& labelsName : labelNames)
{
// take the last label file defined (the other one might be input)
matrices[labelsName] = &labelsMatrix;
matrices.AddInputMatrix(labelsName, &labelsMatrix);

// get the label mapping file name
ConfigParameters labelConfig(readerConfig(labelsName));
Expand Down Expand Up @@ -97,7 +97,7 @@ void DoCreateLabelMap(const ConfigParameters& config)
int count = 0;
while (dataReader.GetMinibatch(matrices))
{
Matrix<ElemType>& features = *matrices[featureNames[0]];
Matrix<ElemType>& features = matrices.GetInputMatrix(featureNames[0]);
count += features.GetNumCols();
if (traceLevel > 1)
fprintf(stderr, "."); // progress meter
Expand Down
8 changes: 4 additions & 4 deletions Source/CNTK/tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ void TestReader(const ConfigParameters& configBase)
Matrix<ElemType> featuresMatrix(deviceId);
Matrix<ElemType> labelsMatrix(deviceId);
StreamMinibatchInputs<ElemType> matrices;
matrices[featureNames[0]] = &featuresMatrix;
matrices[labelNames[0]] = &labelsMatrix;
matrices.AddInputMatrix(featureNames[0], &featuresMatrix);
matrices.AddInputMatrix(labelNames[0] , &labelsMatrix);

auto start = std::chrono::system_clock::now();
int epochs = config("maxEpochs");
Expand All @@ -131,8 +131,8 @@ void TestReader(const ConfigParameters& configBase)
int i = 0;
while (dataReader.GetMinibatch(matrices))
{
Matrix<ElemType>& features = *matrices[featureNames[0]];
Matrix<ElemType>& labels = *matrices[labelNames[0]];
Matrix<ElemType>& features = matrices.GetInputMatrix(featureNames[0]);
Matrix<ElemType>& labels = matrices.GetInputMatrix(labelNames[0]);

if (labels.GetNumRows() == 0)
{
Expand Down
29 changes: 28 additions & 1 deletion Source/Common/Include/DataReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,34 @@ class StreamMinibatchInputs
typedef map<std::wstring, Matrix<ElemType>*> Map;
Map matrices;
public:
Matrix<ElemType>*& operator[](const std::wstring& nodeName) { return matrices[nodeName]; }
void AddInput(const std::wstring& nodeName, Matrix<ElemType>* matrix) { AddInputMatrix(nodeName, matrix); } // use this where entire entry is copied (UCIFastReader::GetMinibatch() async)
// TODO: GetInput() will return a struct
// access to matrix entries
void AddInputMatrix(const std::wstring& nodeName, Matrix<ElemType>* matrix) { matrices[nodeName] = matrix; }
Matrix<ElemType>* GetInputMatrixPtr(const std::wstring& nodeName) // gets matrix, or NULL if no such entry
{
auto iter = matrices.find(nodeName);
if (iter != matrices.end())
return iter->second;
else
return nullptr;
}
Matrix<ElemType>& GetInputMatrix (const std::wstring& nodeName)
{
auto* matrixp = GetInputMatrixPtr(nodeName);
if (!matrixp)
LogicError("GetInputMatrix: Attempted to access non-existent input stream '%ls'", nodeName.c_str());
else
return *matrixp;
}
// some stuff we should get rid off
void FreeInputMatrix(const std::wstring& nodeName) // called by DecimateMinibatch()
{
delete matrices[nodeName];
matrices[nodeName] = nullptr; // TODO: change ownership handling to using shared_ptrs
}
// iterating
// TODO: Abstract this.
typename Map::iterator begin() { return matrices.begin(); }
typename Map::iterator end() { return matrices.end(); }
typename Map::iterator find(const std::wstring& nodeName) { return matrices.find(nodeName); }
Expand Down
6 changes: 3 additions & 3 deletions Source/Readers/DSSMReader/DSSMReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,9 @@ bool DSSMReader<ElemType>::GetMinibatch(StreamMinibatchInputs<ElemType>& matrice
// In my unit test example, the input matrices contain 5: N, S, fD, fQ and labels
// Both N and S serve as a pre-set constant values, no need to change them
// In this node, we only need to fill in these matrices: fD, fQ, labels
Matrix<ElemType>& featuresQ = *matrices[m_featuresNameQuery];
Matrix<ElemType>& featuresD = *matrices[m_featuresNameDoc];
Matrix<ElemType>& labels = *matrices[m_labelsName]; // will change this part later.
Matrix<ElemType>& featuresQ = matrices.GetInputMatrix(m_featuresNameQuery);
Matrix<ElemType>& featuresD = matrices.GetInputMatrix(m_featuresNameDoc);
Matrix<ElemType>& labels = matrices.GetInputMatrix(m_labelsName); // will change this part later. TODO: How?

size_t actualMBSize = (m_readNextSample + m_mbSize > m_totalSamples) ? m_totalSamples - m_readNextSample : m_mbSize;

Expand Down
12 changes: 6 additions & 6 deletions Source/Readers/HTKMLFReader/HTKMLFReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs<Ele
{
// dereference matrix that corresponds to key (input/output name) and
// populate based on whether its a feature or a label
Matrix<ElemType>& data = *matrices[iter->first]; // can be features or labels
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first); // can be features or labels
if (m_nameToTypeMap[iter->first] == InputOutputTypes::real)
{
id = m_featureNameToIdMap[iter->first];
Expand Down Expand Up @@ -1154,7 +1154,7 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs<Ele
{
// dereference matrix that corresponds to key (input/output name) and
// populate based on whether its a feature or a label
Matrix<ElemType>& data = *matrices[iter->first]; // can be features or labels
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first); // can be features or labels

if (m_nameToTypeMap[iter->first] == InputOutputTypes::real)
{
Expand Down Expand Up @@ -1218,7 +1218,7 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs<Ele
{
// dereference matrix that corresponds to key (input/output name) and
// populate based on whether its a feature or a label
Matrix<ElemType>& data = *matrices[iter->first]; // can be features or labels
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first); // can be features or labels

if (m_nameToTypeMap[iter->first] == InputOutputTypes::real)
{
Expand Down Expand Up @@ -1355,7 +1355,7 @@ bool HTKMLFReader<ElemType>::GetMinibatchToTrainOrTest(StreamMinibatchInputs<Ele
{
// dereference matrix that corresponds to key (input/output name) and
// populate based on whether its a feature or a label
Matrix<ElemType>& data = *matrices[iter->first]; // can be features or labels
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first); // can be features or labels
if (m_nameToTypeMap[iter->first] == InputOutputTypes::real)
{
id = m_featureNameToIdMap[iter->first];
Expand Down Expand Up @@ -1398,7 +1398,7 @@ void HTKMLFReader<ElemType>::fillOneUttDataforParallelmode(StreamMinibatchInputs
{
// dereference matrix that corresponds to key (input/output name) and
// populate based on whether its a feature or a label
Matrix<ElemType>& data = *matrices[iter->first]; // can be features or labels
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first); // can be features or labels

if (m_nameToTypeMap[iter->first] == InputOutputTypes::real)
{
Expand Down Expand Up @@ -1512,7 +1512,7 @@ bool HTKMLFReader<ElemType>::GetMinibatchToWrite(StreamMinibatchInputs<ElemType>

if (m_nameToTypeMap.find(iter->first) != m_nameToTypeMap.end() && m_nameToTypeMap[iter->first] == InputOutputTypes::real)
{
Matrix<ElemType>& data = *matrices[iter->first]; // can be features or labels (TODO: Really? Didn't we just ^^^ check that it is 'real'?)
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first); // can be features or labels (TODO: Really? Didn't we just ^^^ check that it is 'real'?)
size_t id = m_featureNameToIdMap[iter->first];
size_t dim = m_featureNameToDimMap[iter->first];

Expand Down
18 changes: 9 additions & 9 deletions Source/Readers/Kaldi2Reader/HTKMLFReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,7 @@ void HTKMLFReader<ElemType>::CopyMinibatchFromBufferToMatrix(
// Copies data to the matrix.
for (auto iter = matrices.begin(); iter != matrices.end(); iter++)
{
Matrix<ElemType>& data = *matrices[iter->first];
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first);
if (m_nameToTypeMap[iter->first] == InputOutputTypes::real)
{
size_t id = m_featureNameToIdMap[iter->first];
Expand Down Expand Up @@ -1400,16 +1400,16 @@ void HTKMLFReader<ElemType>::CopyMinibatchFromBufferToMatrix(
{
if (data.GetNumCols() != m_currentMBSize * m_numberOfuttsPerMinibatch)
{
matrices[iter->first]->Resize(data.GetNumRows(),
m_currentMBSize * m_numberOfuttsPerMinibatch);
matrices.GetInputMatrix(iter->first).Resize(data.GetNumRows(),
m_currentMBSize * m_numberOfuttsPerMinibatch);
}
matrices[iter->first]->SetValue(0);
matrices.GetInputMatrix(iter->first).SetValue(0);
}
else
{
m_uttDerivBuffer->GetDerivative(m_minibatchUttInfo,
m_pMBLayout,
matrices[iter->first]);
matrices.GetInputMatrixPtr(iter->first)); // TODO: use a reference instead of a ptr
}
}
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::readerObj)
Expand All @@ -1425,7 +1425,7 @@ void HTKMLFReader<ElemType>::CopyMinibatchFromBufferToMatrix(
else
{
m_uttDerivBuffer->GetObjective(m_minibatchUttInfo,
matrices[iter->first]);
matrices.GetInputMatrixPtr(iter->first)); // TODO: use a reference instead of a ptr
}
}
}
Expand All @@ -1447,7 +1447,7 @@ void HTKMLFReader<ElemType>::CopyMinibatchToMatrix(
{
for (auto iter = matrices.begin(); iter != matrices.end(); iter++)
{
Matrix<ElemType>& data = *matrices[iter->first];
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first);
if (m_nameToTypeMap.at(iter->first) == InputOutputTypes::real)
{
size_t id = m_featureNameToIdMap.at(iter->first);
Expand Down Expand Up @@ -1590,7 +1590,7 @@ bool HTKMLFReader<ElemType>::GetMinibatchToWrite(StreamMinibatchInputs<ElemType>

if (m_nameToTypeMap.find(iter->first) != m_nameToTypeMap.end() && m_nameToTypeMap[iter->first] == InputOutputTypes::real)
{
Matrix<ElemType>& data = *matrices[iter->first]; // can be features or labels
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first); // can be features or labels
size_t id = m_featureNameToIdMap[iter->first];
size_t dim = m_featureNameToDimMap[iter->first];

Expand Down Expand Up @@ -1643,7 +1643,7 @@ bool HTKMLFReader<ElemType>::GetMinibatchToWrite(StreamMinibatchInputs<ElemType>
}
else
{ // Resizes other inputs so they won't affect actual minibatch size.
Matrix<ElemType>& data = *matrices[iter->first];
Matrix<ElemType>& data = matrices.GetInputMatrix(iter->first);
data.Resize(data.GetNumRows(), 1);
}
}
Expand Down
26 changes: 15 additions & 11 deletions Source/Readers/LMSequenceReader/SequenceReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ void SequenceReader<ElemType>::GetLabelOutput(StreamMinibatchInputs<ElemType>& m
FailBecauseDeprecated(__FUNCTION__); // DEPRECATED CLASS, SHOULD NOT BE USED ANYMORE

size_t j = 0;
Matrix<ElemType>* labels = matrices[m_labelsName[labelInfoOut]];
Matrix<ElemType>* labels = matrices.GetInputMatrixPtr(m_labelsName[labelInfoOut]);
if (labels == nullptr)
return;

Expand Down Expand Up @@ -1033,7 +1033,7 @@ void SequenceReader<ElemType>::GetInputProb(StreamMinibatchInputs<ElemType>& mat
{
FailBecauseDeprecated(__FUNCTION__); // DEPRECATED CLASS, SHOULD NOT BE USED ANYMORE

Matrix<ElemType>* idx2prob = matrices[STRIDX2PROB];
Matrix<ElemType>* idx2prob = matrices.GetInputMatrixPtr(STRIDX2PROB);
if (idx2prob == nullptr)
return;

Expand All @@ -1059,10 +1059,11 @@ void SequenceReader<ElemType>::GetInputProb(StreamMinibatchInputs<ElemType>& mat
m_idx2probRead = true;
}

// TODO: Document what this is. It seems we can fill specific hard-coded inputs with something interesting.
template <class ElemType>
void SequenceReader<ElemType>::GetInputToClass(StreamMinibatchInputs<ElemType>& matrices)
{
Matrix<ElemType>* idx2cls = matrices[STRIDX2CLS];
Matrix<ElemType>* idx2cls = matrices.GetInputMatrixPtr(STRIDX2CLS);
if (idx2cls == nullptr)
return;

Expand Down Expand Up @@ -1197,7 +1198,7 @@ bool SequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs<ElemType>& mat

// loop through all the samples
int j = 0;
Matrix<ElemType>& features = *matrices[m_featuresName];
Matrix<ElemType>& features = matrices.GetInputMatrix(m_featuresName);
if (matrices.find(m_featuresName) != matrices.end())
{
if (features.GetMatrixType() == MatrixType::DENSE)
Expand Down Expand Up @@ -1250,15 +1251,15 @@ bool SequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs<ElemType>& mat
// get the features array
if (matrices.find(m_featuresName) == matrices.end())
{
Matrix<ElemType>& nbs = *matrices[L"numberobs"];
Matrix<ElemType>& nbs = matrices.GetInputMatrix(L"numberobs");
int curDevId = nbs.GetDeviceId();
nbs.TransferFromDeviceToDevice(curDevId, CPUDEVICE, true, false, false);
nbs(0, 0) = (float) actualmbsize;
nbs.TransferFromDeviceToDevice(CPUDEVICE, curDevId, true, false, false);
for (size_t i = 0; i < actualmbsize; i++)
{
std::wstring ws = msra::strfun::wstrprintf(L"feature%d", i);
Matrix<ElemType>& features = *matrices[ws];
Matrix<ElemType>& features = matrices.GetInputMatrix(ws);
features.SetValue(labelInfo.dim, 1, features.GetDeviceId(), &m_featuresBuffer[i * labelInfo.dim], matrixFlagNormal);
}
}
Expand All @@ -1274,18 +1275,21 @@ bool SequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs<ElemType>& mat
{
if (matrices.find(m_labelsName[labelInfoOut]) == matrices.end())
{
// TODO: What is this? Debug code?
for (size_t i = 0; i < actualmbsize; i++)
{
std::wstring ws = msra::strfun::wstrprintf(L"label%d", i);
Matrix<ElemType>* labels = matrices[ws];
labels->SetValue(labelInfo.dim, 1, labels->GetDeviceId(), &m_labelsBuffer[i * labelInfo.dim], matrixFlagNormal);
// This writes into nodes named "labelN", or crashes if they don't exist. Seems this is dead code.
Matrix<ElemType>& labels = matrices.GetInputMatrix(ws);
labels.SetValue(labelInfo.dim, 1, labels.GetDeviceId(), &m_labelsBuffer[i * labelInfo.dim], matrixFlagNormal);
}
}
// BUGBUG? If category labels then this will not output anything if such node is given.
}
else if (labelInfo.type != labelNone)
{
Matrix<ElemType>* labels = matrices[m_labelsName[labelInfoOut]];
labels->SetValue(1, actualmbsize, labels->GetDeviceId(), m_labelsBuffer, matrixFlagNormal);
Matrix<ElemType>& labels = matrices.GetInputMatrix(m_labelsName[labelInfoOut]);
labels.SetValue(1, actualmbsize, labels.GetDeviceId(), m_labelsBuffer, matrixFlagNormal);
}
}
catch (...)
Expand Down Expand Up @@ -2077,7 +2081,7 @@ void BatchSequenceReader<ElemType>::GetLabelOutput(StreamMinibatchInputs<ElemTyp
size_t mbStartSample, size_t actualmbsize)
{
size_t j = 0;
Matrix<ElemType>* labels = matrices[m_labelsName[labelInfoOut]];
Matrix<ElemType>* labels = matrices.GetInputMatrixPtr(m_labelsName[labelInfoOut]);
if (labels == nullptr)
return;

Expand Down
12 changes: 6 additions & 6 deletions Source/Readers/LUSequenceReader/LUSequenceReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs<ElemTyp
{
if (matrices.find(m_featuresName) == matrices.end())
RuntimeError("BatchLUSequenceReader cannot find %ls.", m_featuresName.c_str());
Matrix<ElemType>& features = *matrices[m_featuresName];
Matrix<ElemType>& features = matrices.GetInputMatrix(m_featuresName);

// loop through all the samples and create a one-hot representation, or multi-hot in some conditions (TODO: which condition)
Matrix<ElemType> locObs(CPUDEVICE);
Expand Down Expand Up @@ -908,7 +908,7 @@ bool BatchLUSequenceReader<ElemType>::GetMinibatch(StreamMinibatchInputs<ElemTyp
template <class ElemType>
size_t BatchLUSequenceReader<ElemType>::GetLabelOutput(StreamMinibatchInputs<ElemType>& matrices, LabelInfo& labelInfo, size_t actualmbsize)
{
Matrix<ElemType>* labels = matrices[m_labelsName[labelInfoOut]];
Matrix<ElemType>* labels = matrices.GetInputMatrixPtr(m_labelsName[labelInfoOut]);
if (labels == nullptr)
return 0;

Expand Down Expand Up @@ -1044,7 +1044,7 @@ bool BatchLUSequenceReader<ElemType>::GetFrame(StreamMinibatchInputs<ElemType>&
const LabelInfo& featInfo = m_labelInfo[labelInfoIn];

// loop through all the samples
Matrix<ElemType>& features = *matrices[m_featuresName];
Matrix<ElemType>& features = matrices.GetInputMatrix(m_featuresName);
Matrix<ElemType> locObs(CPUDEVICE);
locObs.SwitchToMatrixType(SPARSE, matrixFormatSparseCSC, false);

Expand Down Expand Up @@ -1100,7 +1100,7 @@ bool BatchLUSequenceReader<ElemType>::GetFrame(StreamMinibatchInputs<ElemType>&
{
assert(mMatrices[p->first]->GetNumCols() > tidx);
if (matrices.find(p->first) != matrices.end())
matrices[p->first]->SetValue(mMatrices[p->first]->ColumnSlice(tidx, mRequestedNumParallelSequences));
matrices.GetInputMatrix(p->first).SetValue(mMatrices[p->first]->ColumnSlice(tidx, mRequestedNumParallelSequences));
}
}

Expand All @@ -1117,12 +1117,12 @@ void BatchLUSequenceReader<ElemType>::InitProposals(StreamMinibatchInputs<ElemTy
{
// no need to save info for labelInfoIn since it is in mProposals
if (pMat.find(m_labelsName[labelInfoOut]) != pMat.end())
mMatrices[m_labelsName[labelInfoOut]]->SetValue(*(pMat[m_labelsName[labelInfoOut]]));
mMatrices[m_labelsName[labelInfoOut]]->SetValue(pMat.GetInputMatrix(m_labelsName[labelInfoOut]));
}
else
{
if (pMat.find(m_featuresName) != pMat.end())
mMatrices[m_featuresName]->SetValue(*(pMat[m_featuresName]));
mMatrices[m_featuresName]->SetValue(pMat.GetInputMatrix(m_featuresName));
}
}

Expand Down
Loading

0 comments on commit f3a720e

Please sign in to comment.