Skip to content

Commit

Permalink
Fix KaldiReader for the new interface (SetValue).
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhang87 committed Oct 28, 2015
1 parent 8228164 commit 6619421
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
16 changes: 13 additions & 3 deletions DataReader/Kaldi2Reader/HTKMLFReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// We initialize the sentence boundary information before we process
// the utterances.
m_pMBLayout->Init(m_numberOfuttsPerMinibatch, m_currentMBSize, !m_framemode);
for (size_t i = 0; i < m_numberOfuttsPerMinibatch; i++)
{
for (size_t j = 0; j < m_currentMBSize; j++)
{
m_pMBLayout->SetWithoutOr(i, j, MinibatchPackingFlags::None);
}
}

// Iterates over utterances. m_numberOfuttsPerMinibatch = 1 is a
// special case.
for (size_t i = 0; i < m_numberOfuttsPerMinibatch; i++)
Expand Down Expand Up @@ -1412,6 +1420,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
assert(id < m_minibatchBuffer[index].features.size());
data.SetValue(dim,
m_minibatchBuffer[index].features[id].size() / dim,
data.GetDeviceId(),
m_minibatchBuffer[index].features[id].data(),
matrixFlagNormal);
}
Expand All @@ -1422,6 +1431,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
assert(id < m_minibatchBuffer[index].labels.size());
data.SetValue(dim,
m_minibatchBuffer[index].labels[id].size() / dim,
data.GetDeviceId(),
m_minibatchBuffer[index].labels[id].data(),
matrixFlagNormal);
}
Expand Down Expand Up @@ -1488,14 +1498,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
size_t id = m_featureNameToIdMap.at(iter->first);
size_t dim = m_featureNameToDimMap.at(iter->first);
assert(id < featureBuffer.size());
data.SetValue(dim, size, featureBuffer[id] , matrixFlagNormal);
data.SetValue(dim, size, data.GetDeviceId(), featureBuffer[id] , matrixFlagNormal);
}
else if (m_nameToTypeMap.at(iter->first) == InputOutputTypes::category)
{
size_t id = m_labelNameToIdMap.at(iter->first);
size_t dim = m_labelNameToDimMap.at(iter->first);
assert(id < labelBuffer.size());
data.SetValue(dim, size, labelBuffer[id], matrixFlagNormal);
data.SetValue(dim, size, data.GetDeviceId(), labelBuffer[id], matrixFlagNormal);
}
else if (m_doMinibatchBuffering)
{
Expand Down Expand Up @@ -1674,7 +1684,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
}
data.SetValue(feat.rows(), feat.cols(), m_featuresBufferMultiIO[id],matrixFlagNormal);
data.SetValue(feat.rows(), feat.cols(), data.GetDeviceId(), m_featuresBufferMultiIO[id],matrixFlagNormal);
}
else
{ // Resizes other inputs so they won't affect actual minibatch size.
Expand Down
10 changes: 5 additions & 5 deletions DataReader/KaldiReader/HTKMLFReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
}
data.SetValue(feat.rows(), feat.cols(), m_featuresBufferMultiIO[id],matrixFlagNormal);
data.SetValue(feat.rows(), feat.cols(), data.GetDeviceId(), m_featuresBufferMultiIO[id],matrixFlagNormal);
}
}
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::category)
Expand Down Expand Up @@ -919,7 +919,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}


data.SetValue(dim,uids.size(),m_labelsBufferMultiIO[id],matrixFlagNormal);
data.SetValue(dim,uids.size(),data.GetDeviceId(), m_labelsBufferMultiIO[id],matrixFlagNormal);
}

}
Expand Down Expand Up @@ -1190,13 +1190,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
id = m_featureNameToIdMap[iter->first];
dim = m_featureNameToDimMap[iter->first];
data.SetValue(dim, m_mbSize*m_numberOfuttsPerMinibatch, m_featuresBufferMultiIO[id],matrixFlagNormal);
data.SetValue(dim, m_mbSize*m_numberOfuttsPerMinibatch, data.GetDeviceId(), m_featuresBufferMultiIO[id],matrixFlagNormal);
}
else if (m_nameToTypeMap[iter->first] == InputOutputTypes::category)
{
id = m_labelNameToIdMap[iter->first];
dim = m_labelNameToDimMap[iter->first];
data.SetValue(dim, m_mbSize*m_numberOfuttsPerMinibatch, m_labelsBufferMultiIO[id],matrixFlagNormal);
data.SetValue(dim, m_mbSize*m_numberOfuttsPerMinibatch, data.GetDeviceId(), m_labelsBufferMultiIO[id],matrixFlagNormal);
}
}
skip=false;
Expand Down Expand Up @@ -1317,7 +1317,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
}
}
data.SetValue(feat.rows(), feat.cols(), m_featuresBufferMultiIO[id],matrixFlagNormal);
data.SetValue(feat.rows(), feat.cols(), data.GetDeviceId(), m_featuresBufferMultiIO[id],matrixFlagNormal);
}
}
return true;
Expand Down

0 comments on commit 6619421

Please sign in to comment.