diff --git a/DataReader/ImageReader/ImageReader.cpp b/DataReader/ImageReader/ImageReader.cpp
index 669603a28bb7..d8a7cd592b54 100644
--- a/DataReader/ImageReader/ImageReader.cpp
+++ b/DataReader/ImageReader/ImageReader.cpp
@@ -6,14 +6,12 @@
 #include
 #include
 #include
-#include
-
-#include
+#include
 
 namespace Microsoft { namespace MSR { namespace CNTK {
 
 template
-ImageReader::ImageReader() : m_seed(0)
+ImageReader::ImageReader() : m_rng(0u), m_rndUniInt(0, INT_MAX), m_seed(0) // m_rng is declared (so initialized) before m_seed: seed it with the literal, not m_seed
 {
 }
 
@@ -42,6 +40,12 @@ void ImageReader::Init(const ConfigParameters& config)
     m_imgHeight = featSect.second("height");
     m_imgChannels = featSect.second("channels");
     m_featDim = m_imgWidth * m_imgHeight * m_imgChannels;
+    m_meanFile = featSect.second(L"meanFile", L"");
+
+    m_cropType = ParseCropType(featSect.second("cropType", ""));
+    m_cropRatio = std::stof(featSect.second("cropRatio", "1"));
+    if (!(0 < m_cropRatio && m_cropRatio <= 1.0f))
+        RuntimeError("Invalid cropRatio value: %f.", m_cropRatio);
 
     SectionT labSect{ gettter("labelDim") };
     m_labName = msra::strfun::utf16(labSect.first);
@@ -63,8 +67,7 @@ void ImageReader::Init(const ConfigParameters& config)
         files.push_back({ imgPath, std::stoi(clsId) });
     }
 
-    std::default_random_engine rng(m_seed);
-    std::shuffle(files.begin(), files.end(), rng);
+    std::shuffle(files.begin(), files.end(), m_rng);
 
     m_epochStart = 0;
     m_mbStart = 0;
@@ -116,20 +119,22 @@ bool ImageReader::GetMinibatch(std::map
 
     std::fill(m_labBuf.begin(), m_labBuf.end(), static_cast(0));
 
-#pragma omp parallel for ordered schedule(dynamic)
+//#pragma omp parallel for ordered schedule(dynamic)
     for (long long i = 0; i < static_cast(mbLim - m_mbStart); i++)
     {
         const auto& p = files[i + m_mbStart];
         auto img = cv::imread(p.first, cv::IMREAD_COLOR);
 
         // Crop
-        int w = img.cols;
-        int h = img.rows;
-        int cropSize = std::min(w, h);
-        int xOff = (w - cropSize) / 2;
-        int yOff = (h - cropSize) / 2;
-        cv::Mat cropped{ img(cv::Rect(xOff, yOff, cropSize, cropSize)) };
-        cropped.convertTo(img, CV_32F);
+        cv::Mat cropped;
+        CropTransform(img, cropped);
+        //int w = img.cols;
+        //int h = img.rows;
+        //int cropSize = std::min(w, h);
+        //int xOff = (w - cropSize) / 2;
+        //int yOff = (h - cropSize) / 2;
+        //cv::Mat cropped{ img(cv::Rect(xOff, yOff, cropSize, cropSize)) };
+        cropped.convertTo(img, CV_32F);
 
         // Scale
         cv::resize(img, img, cv::Size(static_cast(m_imgWidth), static_cast(m_imgHeight)), 0, 0, cv::INTER_LINEAR);
@@ -173,6 +178,77 @@ template
 void ImageReader::SetRandomSeed(unsigned int seed)
 {
     m_seed = seed;
+    m_rng.seed(m_seed);
+}
+
+// Maps the "cropType" configuration string to a CropType value.
+// An empty string or "center" selects CropType::Center and "random" selects CropType::Random
+// (letter case is ignored); any other value is reported through RuntimeError.
+template
+typename ImageReader::CropType ImageReader::ParseCropType(const std::string& src)
+{
+    // Whole-string, case-insensitive comparison. The size check keeps std::equal from reading
+    // past the end of s2 and stops a mere prefix such as "c" from matching "center".
+    auto AreEqual = [](const std::string& s1, const std::string& s2) -> bool
+    {
+        return s1.size() == s2.size() && std::equal(s1.begin(), s1.end(), s2.begin(), [](unsigned char a, unsigned char b) { return std::tolower(a) == std::tolower(b); });
+    };
+
+    if (src.empty() || AreEqual(src, "center"))
+        return CropType::Center;
+    if (AreEqual(src, "random"))
+        return CropType::Random;
+
+    RuntimeError("Invalid crop type: %s.", src.c_str());
+}
+
+// Returns a square crop rectangle for an image with crow rows and ccol columns.
+// The side is min(crow, ccol) * cropRatio converted to int. CropType::Center centers the square;
+// CropType::Random draws both offsets from m_rng so that the square stays inside the image.
+template
+cv::Rect ImageReader::GetCropRect(CropType type, int crow, int ccol, float cropRatio)
+{
+    assert(crow > 0);
+    assert(ccol > 0);
+    assert(0 < cropRatio && cropRatio <= 1.0f);
+
+    int cropSize = static_cast(std::min(crow, ccol) * cropRatio);
+    int xOff = -1;
+    int yOff = -1;
+
+    switch (type)
+    {
+    case CropType::Center:
+        xOff = (ccol - cropSize) / 2;
+        yOff = (crow - cropSize) / 2;
+        break;
+    case CropType::Random:
+        // Offsets 0..(dim - cropSize) inclusive are valid (see the asserts below), hence the "+ 1";
+        // without it the modulus is zero whenever the crop spans a whole dimension.
+        xOff = m_rndUniInt(m_rng) % (ccol - cropSize + 1);
+        yOff = m_rndUniInt(m_rng) % (crow - cropSize + 1);
+        break;
+    default:
+        assert(false);
+    }
+
+    assert(0 <= xOff && xOff <= ccol - cropSize);
+    assert(0 <= yOff && yOff <= crow - cropSize);
+    return cv::Rect(xOff, yOff, cropSize, cropSize);
+}
+
+// Stores in dst a copy (cv::Mat::clone) of the region of src selected by GetCropRect.
+template
+void ImageReader::CropTransform(const cv::Mat& src, cv::Mat& dst)
+{
+    // REVIEW alexeyk: optimize resizing?
+    dst = src(GetCropRect(m_cropType, src.rows, src.cols, m_cropRatio)).clone();
+}
+
+// Mean subtraction is not implemented yet: the body is empty and both arguments are ignored.
+template
+void ImageReader::SubMeanTransform(const cv::Mat& , cv::Mat& )
+{
 }
 
 template class ImageReader;
diff --git a/DataReader/ImageReader/ImageReader.h b/DataReader/ImageReader/ImageReader.h
index a9c5a7af384a..dfc38ba10b9d 100644
--- a/DataReader/ImageReader/ImageReader.h
+++ b/DataReader/ImageReader/ImageReader.h
@@ -6,6 +6,8 @@
 // ImageReader.h - Include file for the image reader
 
 #pragma once
+#include
+#include
 #include "DataReader.h"
 
 namespace Microsoft { namespace MSR { namespace CNTK {
@@ -29,17 +31,23 @@ class ImageReader : public IDataReader
     bool DataEnd(EndDataType endDataType) override;
 
     size_t NumberSlicesInEachRecurrentIter() override { return 0; }
-    void SetSentenceSegBatch(Matrix &, vector&) override { };
+    void SetSentenceSegBatch(Matrix&, vector&) override { };
 
     void SetRandomSeed(unsigned int seed) override;
 
-    //virtual const std::map& GetLabelMapping(const std::wstring& sectionName);
-    //virtual void SetLabelMapping(const std::wstring& sectionName, const std::map& labelMapping);
-    //virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);
-    //void SetSentenceSegBatch(Matrix&, Matrix&) { };
-    //void SetNbrSlicesEachRecurrentIter(const size_t sz);
+private:
+    enum class CropType { Center = 0, Random = 1 };
+
+    CropType ParseCropType(const std::string& src);
+    cv::Rect GetCropRect(CropType type, int crow, int ccol, float cropRatio);
+    void CropTransform(const cv::Mat& src, cv::Mat& dst);
+
+    void SubMeanTransform(const cv::Mat& src, cv::Mat& dst);
 
 private:
+    std::default_random_engine m_rng;
+    std::uniform_int_distribution m_rndUniInt;
+
     std::wstring m_featName;
     std::wstring m_labName;
 
@@ -62,5 +70,10 @@ class ImageReader : public IDataReader
     std::vector m_labBuf;
 
     unsigned int m_seed;
+
+    CropType m_cropType;
+    float m_cropRatio;
+
+    std::wstring m_meanFile;
 };
 }}}