Skip to content

Commit

Permalink
Do not create a new packer for each epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Reznichenko committed Apr 22, 2016
1 parent c3a0f56 commit 0651892
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 92 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ READER_SRC =\
$(SOURCEDIR)/Readers/ReaderLib/ChunkRandomizer.cpp \
$(SOURCEDIR)/Readers/ReaderLib/SequenceRandomizer.cpp \
$(SOURCEDIR)/Readers/ReaderLib/SequencePacker.cpp \
$(SOURCEDIR)/Readers/ReaderLib/BpttPacker.cpp \
$(SOURCEDIR)/Readers/ReaderLib/TruncatedBpttPacker.cpp \
$(SOURCEDIR)/Readers/ReaderLib/PackerBase.cpp \
$(SOURCEDIR)/Readers/ReaderLib/FramePacker.cpp \
Expand Down
17 changes: 9 additions & 8 deletions Source/Readers/CNTKTextFormatReader/CNTKTextFormatReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ CNTKTextFormatReader::CNTKTextFormatReader(MemoryProviderPtr provider,
randomizer->Initialize(nullptr, config);

m_transformer = randomizer;

// TODO: add "frameMode" config paramter
m_packer = std::make_shared<SequencePacker>(
m_provider,
m_transformer,
GetStreamDescriptions());
}
catch (const std::runtime_error& e)
{
Expand All @@ -61,18 +67,13 @@ std::vector<StreamDescriptionPtr> CNTKTextFormatReader::GetStreamDescriptions()

void CNTKTextFormatReader::StartEpoch(const EpochConfiguration& config)
{
if (config.m_totalEpochSizeInSamples <= 0)
if (config.m_totalEpochSizeInSamples == 0)
{
RuntimeError("Unsupported minibatch size '%d'.", (int)config.m_totalEpochSizeInSamples);
RuntimeError("Epoch size cannot be 0.");
}

m_transformer->StartEpoch(config);
// TODO: add "frameMode" config paramter
m_packer = std::make_shared<SequencePacker>(
m_provider,
m_transformer,
config.m_minibatchSizeInSamples,
GetStreamDescriptions());
m_packer->StartEpoch(config);
}

Minibatch CNTKTextFormatReader::ReadMinibatch()
Expand Down
78 changes: 41 additions & 37 deletions Source/Readers/ExperimentalHTKMLFReader/HTKMLFReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "StringUtil.h"
#include "FramePacker.h"
#include "SequencePacker.h"
#include "BpttPacker.h"
#include "TruncatedBpttPacker.h"
#include "BlockRandomizer.h"
#include "NoRandomizer.h"

Expand Down Expand Up @@ -136,22 +136,6 @@ HTKMLFReader::HTKMLFReader(MemoryProviderPtr provider,
m_streams.push_back(stream);
}
}
}

std::vector<StreamDescriptionPtr> HTKMLFReader::GetStreamDescriptions()
{
assert(!m_streams.empty());
return m_streams;
}

void HTKMLFReader::StartEpoch(const EpochConfiguration& config)
{
if (config.m_totalEpochSizeInSamples <= 0)
{
RuntimeError("Unsupported minibatch size '%d'.", (int)config.m_totalEpochSizeInSamples);
}

m_randomizer->StartEpoch(config);

// TODO: should we unify sample and sequence mode packers into a single one.
// TODO: functionally they are the same, the only difference is how we handle
Expand All @@ -164,20 +148,35 @@ void HTKMLFReader::StartEpoch(const EpochConfiguration& config)
switch (m_packingMode)
{
case PackingMode::sample:
m_packer = std::make_shared<FramePacker>(
m_provider,
m_randomizer,
config.m_minibatchSizeInSamples,
m_streams);
m_packer = std::make_shared<FramePacker>(m_provider, m_randomizer, m_streams);
break;
case PackingMode::sequence:
m_packer = std::make_shared<SequencePacker>(
m_provider,
m_randomizer,
config.m_minibatchSizeInSamples,
m_streams);
m_packer = std::make_shared<SequencePacker>(m_provider, m_randomizer, m_streams);
break;
case PackingMode::truncated:
m_packer = std::make_shared<TruncatedBPTTPacker>(m_provider, m_randomizer, m_streams);
break;
default:
LogicError("Unsupported type of packer '%d'.", (int)m_packingMode);
}
}

std::vector<StreamDescriptionPtr> HTKMLFReader::GetStreamDescriptions()
{
assert(!m_streams.empty());
return m_streams;
}

void HTKMLFReader::StartEpoch(const EpochConfiguration& config)
{
if (config.m_totalEpochSizeInSamples == 0)
{
RuntimeError("Epoch size cannot be 0.");
}



if (m_packingMode == PackingMode::truncated)
{
size_t minibatchSize = config.m_minibatchSizeInSamples;
size_t truncationLength = m_truncationLength;
Expand All @@ -191,17 +190,22 @@ void HTKMLFReader::StartEpoch(const EpochConfiguration& config)
size_t numParallelSequences = m_numParallelSequencesForAllEpochs[config.m_epochIndex];
minibatchSize = numParallelSequences * truncationLength;
}

m_packer = std::make_shared<BpttPacker>(
m_provider,
m_randomizer,
minibatchSize,
truncationLength,
m_streams);
break;

EpochConfiguration bpttConfig;
bpttConfig.m_numberOfWorkers = config.m_numberOfWorkers;
bpttConfig.m_workerRank = config.m_workerRank;
bpttConfig.m_totalEpochSizeInSamples = config.m_totalEpochSizeInSamples;
bpttConfig.m_epochIndex = config.m_epochIndex;
bpttConfig.m_minibatchSizeInSamples = minibatchSize;
bpttConfig.m_truncationSize = truncationLength;

m_randomizer->StartEpoch(bpttConfig);
m_packer->StartEpoch(bpttConfig);
}
default:
LogicError("Unsupported type of packer '%d'.", (int)m_packingMode);
else
{
m_randomizer->StartEpoch(config);
m_packer->StartEpoch(config);
}
}

Expand Down
15 changes: 8 additions & 7 deletions Source/Readers/ImageReader/ImageReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ ImageReader::ImageReader(MemoryProviderPtr provider,
}

m_transformer = last;

m_packer = std::make_shared<FramePacker>(
m_provider,
m_transformer,
m_streams);
}

std::vector<StreamDescriptionPtr> ImageReader::GetStreamDescriptions()
Expand All @@ -80,17 +85,13 @@ std::vector<StreamDescriptionPtr> ImageReader::GetStreamDescriptions()

void ImageReader::StartEpoch(const EpochConfiguration& config)
{
if (config.m_totalEpochSizeInSamples <= 0)
if (config.m_totalEpochSizeInSamples == 0)
{
RuntimeError("Unsupported minibatch size '%u'.", (int)config.m_totalEpochSizeInSamples);
RuntimeError("Epoch size cannot be 0.");
}

m_transformer->StartEpoch(config);
m_packer = std::make_shared<FramePacker>(
m_provider,
m_transformer,
config.m_minibatchSizeInSamples,
m_streams);
m_packer->StartEpoch(config);
}

Minibatch ImageReader::ReadMinibatch()
Expand Down
3 changes: 1 addition & 2 deletions Source/Readers/ReaderLib/FramePacker.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ class FramePacker : public SequencePacker
FramePacker(
MemoryProviderPtr memoryProvider,
TransformerPtr transformer,
size_t minibatchSize,
const std::vector<StreamDescriptionPtr>& streams) :
SequencePacker(memoryProvider, transformer, minibatchSize, streams)
SequencePacker(memoryProvider, transformer, streams)
{

}
Expand Down
3 changes: 3 additions & 0 deletions Source/Readers/ReaderLib/Packer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
class Packer
{
public:
// Sets current epoch configuration.
virtual void StartEpoch(const EpochConfiguration& config) = 0;

virtual Minibatch ReadMinibatch() = 0;
virtual ~Packer() {}
};
Expand Down
17 changes: 10 additions & 7 deletions Source/Readers/ReaderLib/PackerBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,26 @@ void PackerBase::StreamBuffer::Resize(size_t newSize)
});
}

void PackerBase::StartEpoch(const EpochConfiguration& config)
{
m_minibatchSize = config.m_minibatchSizeInSamples;
if (m_minibatchSize == 0)
{
LogicError("Minibatch size cannot be zero.");
}
}

PackerBase::PackerBase(MemoryProviderPtr memoryProvider,
TransformerPtr transformer,
size_t minibatchSize,
const std::vector<StreamDescriptionPtr>& streams) :
m_transformer(transformer),
m_minibatchSize(minibatchSize),
m_minibatchSize(0),
m_outputStreamDescriptions(streams)
{
m_inputStreamDescriptions = m_transformer->GetStreamDescriptions();
assert(m_inputStreamDescriptions.size() != 0);
assert(m_inputStreamDescriptions.size() == m_outputStreamDescriptions.size());

if (m_minibatchSize == 0)
{
LogicError("Minibatch size cannot be zero.");
}

m_streamBuffers.reserve(m_outputStreamDescriptions.size());

// Sanity checks:
Expand Down
5 changes: 4 additions & 1 deletion Source/Readers/ReaderLib/PackerBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class PackerBase : public Packer

PackerBase(MemoryProviderPtr memoryProvider,
TransformerPtr transformer,
size_t minibatchSize,
const std::vector<StreamDescriptionPtr>& streams);

typedef std::vector<SequenceDataPtr> StreamBatch;
Expand Down Expand Up @@ -71,6 +70,10 @@ class PackerBase : public Packer

// Minibatch size in samples.
size_t m_minibatchSize;

public:
// Sets current epoch configuration.
virtual void StartEpoch(const EpochConfiguration& config) override;
};

inline void PackerBase::PackSparseSampleAsDense(char* destination, SparseSequenceDataPtr sequence,
Expand Down
1 change: 1 addition & 0 deletions Source/Readers/ReaderLib/Reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct EpochConfiguration
size_t m_minibatchSizeInSamples; // Maximum minibatch size for the epoch in samples
size_t m_totalEpochSizeInSamples; // Total size of the epoch in samples
size_t m_epochIndex; // Current epoch index [0 .. max number of epochs)
size_t m_truncationSize; // Truncation size in samples for truncated BPTT mode.
};

// Supported primitive element types, will be extended in the future.
Expand Down
4 changes: 2 additions & 2 deletions Source/Readers/ReaderLib/ReaderLib.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="BpttPacker.h" />
<ClInclude Include="Bundler.h" />
<ClInclude Include="ChunkRandomizer.h" />
<ClInclude Include="DataDeserializerBase.h" />
Expand All @@ -98,9 +97,9 @@
<ClInclude Include="Reader.h" />
<ClInclude Include="ReaderShim.h" />
<ClInclude Include="Transformer.h" />
<ClInclude Include="TruncatedBpttPacker.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="BpttPacker.cpp" />
<ClCompile Include="Bundler.cpp" />
<ClCompile Include="ChunkRandomizer.cpp" />
<ClCompile Include="NoRandomizer.cpp" />
Expand All @@ -110,6 +109,7 @@
<ClCompile Include="ReaderShim.cpp" />
<ClCompile Include="SequencePacker.cpp" />
<ClCompile Include="SequenceRandomizer.cpp" />
<ClCompile Include="TruncatedBpttPacker.cpp" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
Expand Down
4 changes: 2 additions & 2 deletions Source/Readers/ReaderLib/ReaderLib.vcxproj.filters
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
<ClInclude Include="FramePacker.h">
<Filter>Packers</Filter>
</ClInclude>
<ClInclude Include="BpttPacker.h">
<ClInclude Include="TruncatedBpttPacker.h">
<Filter>Packers</Filter>
</ClInclude>
</ItemGroup>
Expand Down Expand Up @@ -93,7 +93,7 @@
<ClCompile Include="FramePacker.cpp">
<Filter>Packers</Filter>
</ClCompile>
<ClCompile Include="BpttPacker.cpp">
<ClCompile Include="TruncatedBpttPacker.cpp">
<Filter>Packers</Filter>
</ClCompile>
</ItemGroup>
Expand Down
3 changes: 1 addition & 2 deletions Source/Readers/ReaderLib/SequencePacker.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ class SequencePacker : public PackerBase
SequencePacker(
MemoryProviderPtr memoryProvider,
TransformerPtr transformer,
size_t minibatchSize,
const std::vector<StreamDescriptionPtr>& streams) :
PackerBase(memoryProvider, transformer, minibatchSize, streams)
PackerBase(memoryProvider, transformer, streams)
{

}
Expand Down
Loading

0 comments on commit 0651892

Please sign in to comment.