Skip to content

Commit

Permalink
Some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
eldakms committed May 18, 2016
1 parent 0a09d41 commit 5469f44
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 37 deletions.
14 changes: 10 additions & 4 deletions Source/Readers/CNTKTextFormatReader/Exports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "ReaderShim.h"
#include "CNTKTextFormatReader.h"
#include "HeapMemoryProvider.h"
#include "StringUtil.h"

namespace Microsoft { namespace MSR { namespace CNTK {

Expand All @@ -32,20 +33,25 @@ extern "C" DATAREADER_API void GetReaderD(IDataReader** preader)
}

// TODO: Not safe from the ABI perspective. Will be uglified to make the interface ABI-safe.
// A factory method for creating image deserializers.
// A factory method for creating text deserializers.
extern "C" DATAREADER_API bool CreateDeserializer(IDataDeserializer** deserializer, const std::wstring& type, const ConfigParameters& deserializerConfig, CorpusDescriptorPtr corpus, bool)
{
string precision = deserializerConfig.Find("precision", "float");
if (!AreEqualIgnoreCase(precision, "float") && !AreEqualIgnoreCase(precision, "double"))
{
InvalidArgument("Unsupported precision '%s'", precision.c_str());
}

// TODO: Remove type from the parser. Current implementation does not support streams of different types.
if (type == L"CNTKTextFormatDeserializer")
{
if (precision == "float")
*deserializer = new TextParser<float>(corpus, TextConfigHelper(deserializerConfig));
else // Currently assume double, TODO: should change when support more types.
else // double
*deserializer = new TextParser<double>(corpus, TextConfigHelper(deserializerConfig));
}
else
// Unknown type.
return false;
InvalidArgument("Unknown deserializer type '%ls'", type.c_str());

// Deserializer created.
return true;
Expand Down
16 changes: 7 additions & 9 deletions Source/Readers/CNTKTextFormatReader/Indexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ void Indexer::Build(CorpusDescriptorPtr corpus)
sd.m_fileOffsetBytes = offset;
sd.m_isValid = true;

auto& stringRegistry = corpus->GetStringRegistry();
while (!m_done)
{
SkipLine(); // ignore whatever is left on this line.
Expand All @@ -175,13 +174,7 @@ void Indexer::Build(CorpusDescriptorPtr corpus)
{
// found a new sequence, which starts at the [offset] bytes into the file
sd.m_byteSize = offset - sd.m_fileOffsetBytes;
auto key = msra::strfun::utf16(std::to_string(sd.m_id));
if (corpus->IsIncluded(key))
{
sd.m_key.m_sequence = stringRegistry[key];
sd.m_key.m_sample = 0;
AddSequence(sd);
}
AddSequenceIfIncluded(corpus, sd);

sd = {};
sd.m_id = id;
Expand All @@ -192,6 +185,12 @@ void Indexer::Build(CorpusDescriptorPtr corpus)

// calculate the byte size for the last sequence
sd.m_byteSize = m_fileOffsetEnd - sd.m_fileOffsetBytes;
AddSequenceIfIncluded(corpus, sd);
}

void Indexer::AddSequenceIfIncluded(CorpusDescriptorPtr corpus, SequenceDescriptor& sd)
{
auto& stringRegistry = corpus->GetStringRegistry();
auto key = msra::strfun::utf16(std::to_string(sd.m_id));
if (corpus->IsIncluded(key))
{
Expand All @@ -201,7 +200,6 @@ void Indexer::Build(CorpusDescriptorPtr corpus)
}
}


void Indexer::SkipLine()
{
while (!m_done)
Expand Down
3 changes: 3 additions & 0 deletions Source/Readers/CNTKTextFormatReader/Indexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class Indexer
// (except when a sequence size is greater than the maximum chunk size)
void AddSequence(SequenceDescriptor& sd);

// Same as above, but the sequence is added only if it is included in the corpus descriptor.
void AddSequenceIfIncluded(CorpusDescriptorPtr corpus, SequenceDescriptor& sd);

// fills up the buffer with data from file, all previously buffered data
// will be overwritten.
void RefillBuffer();
Expand Down
16 changes: 3 additions & 13 deletions Source/Readers/ExperimentalHTKMLFReader/HTKDataDeserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,12 @@ void HTKDataDeserializer::InitializeChunkDescriptions(ConfigHelper& config)
}

wstring key = description.GetKey();
size_t id = 0;
if (m_primary)
if (!m_corpus->IsIncluded(key))
{
// TODO: Definition of the corpus should be moved to the CorpusDescriptor
// TODO: All keys should be added there. Currently, we add them in the driving deserializer.
id = stringRegistry.AddValue(key);
}
else
{
if (!stringRegistry.TryGet(key, id))
{
// Utterance is unknown, skipping it.
continue;
}
continue;
}

size_t id = stringRegistry[key];
description.SetId(id);
utterances.push_back(description);
m_totalNumberOfFrames += numberOfFrames;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void MLFDataDeserializer::InitializeChunkDescriptions(CorpusDescriptorPtr corpus
description.m_isValid = true;
size_t totalFrames = 0;

auto& stringRegistry = corpus->GetStringRegistry();
const auto& stringRegistry = corpus->GetStringRegistry();

// TODO resize m_keyToSequence with number of IDs from string registry

Expand Down
14 changes: 8 additions & 6 deletions Source/Readers/ImageReader/ImageDataDeserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,12 @@ void ImageDataDeserializer::CreateSequenceDescriptions(CorpusDescriptorPtr corpu
description.m_numberOfSamples = 1;
description.m_isValid = true;

WStringToIdMap& corpusStringRegistry = corpus->GetStringRegistry();
auto& stringRegistry = corpus->GetStringRegistry();
for (size_t lineIndex = 0; std::getline(mapFile, line); ++lineIndex)
{
std::stringstream ss(line);
std::string imagePath, classId, sequenceKey;
// Try to parse sequence id, file path and label.
if (!std::getline(ss, sequenceKey, '\t') || !std::getline(ss, imagePath, '\t') || !std::getline(ss, classId, '\t'))
{
// When the sequence key is not specified, we set it to the line number inside the mapping file.
Expand All @@ -261,9 +262,9 @@ void ImageDataDeserializer::CreateSequenceDescriptions(CorpusDescriptorPtr corpu
RuntimeError("Invalid map file format, must contain 2 or 3 tab-delimited columns, line %" PRIu64 " in file %s.", lineIndex, mapPath.c_str());
}

auto wsequenceKey = msra::strfun::utf16(sequenceKey);
// Skipping sequences that are not included in corpus.
if (!corpus->IsIncluded(wsequenceKey))
auto key = msra::strfun::utf16(sequenceKey);
if (!corpus->IsIncluded(key))
{
continue;
}
Expand All @@ -287,7 +288,7 @@ void ImageDataDeserializer::CreateSequenceDescriptions(CorpusDescriptorPtr corpu
description.m_chunkId = curId;
description.m_path = imagePath;
description.m_classId = cid;
description.m_key.m_sequence = corpusStringRegistry[wsequenceKey];
description.m_key.m_sequence = stringRegistry[key];
description.m_key.m_sample = 0;

m_keyToSequence[description.m_key.m_sequence] = m_imageSequences.size();
Expand Down Expand Up @@ -357,14 +358,15 @@ cv::Mat FileByteReader::Read(size_t, const std::string& path, bool grayscale)
return cv::imread(path, grayscale ? cv::IMREAD_GRAYSCALE : cv::IMREAD_COLOR);
}

static SequenceDescription s_InvalidSequence{0, 0, 0, false};
static SequenceDescription s_invalidSequence{0, 0, 0, false};

void ImageDataDeserializer::GetSequenceDescriptionByKey(const KeyType& key, SequenceDescription& result)
{
auto index = m_keyToSequence.find(key.m_sequence);
// Checks whether it is a known sequence for us.
if (key.m_sample != 0 || index == m_keyToSequence.end())
{
result = s_InvalidSequence;
result = s_invalidSequence;
return;
}

Expand Down
9 changes: 6 additions & 3 deletions Source/Readers/ImageReader/ImageDataDeserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ class ImageDataDeserializer : public DataDeserializerBase
explicit ImageDataDeserializer(const ConfigParameters& config);

// Gets the chunk with the specified id.
ChunkPtr GetChunk(size_t chunkId) override;
virtual ChunkPtr GetChunk(size_t chunkId) override;

// Gets chunk descriptions.
ChunkDescriptions GetChunkDescriptions() override;
virtual ChunkDescriptions GetChunkDescriptions() override;

// Gets sequence descriptions for the chunk.
void GetSequencesForChunk(size_t, std::vector<SequenceDescription>&) override;
virtual void GetSequencesForChunk(size_t, std::vector<SequenceDescription>&) override;

// Gets sequence description by key.
void GetSequenceDescriptionByKey(const KeyType&, SequenceDescription&) override;

private:
Expand All @@ -59,6 +60,8 @@ class ImageDataDeserializer : public DataDeserializerBase

// Sequence descriptions for all input data.
std::vector<ImageSequenceDescription> m_imageSequences;

// Mapping of logical sequence key to the index of its description in m_imageSequences.
std::map<size_t, size_t> m_keyToSequence;

// Element type of the feature/label stream (currently float/double only).
Expand Down
2 changes: 1 addition & 1 deletion Source/Readers/ReaderLib/StringToIdMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TStringToIdMap
}

// Tries to get an id by value.
bool TryGet(const TString& value, size_t& id)
bool TryGet(const TString& value, size_t& id) const
{
const auto& it = m_values.find(value);
if (it == m_values.end())
Expand Down

0 comments on commit 5469f44

Please sign in to comment.