Skip to content

Commit

Permalink
Converted Resize -> RequireSize, and separated resizing dimensions wi…
Browse files Browse the repository at this point in the history
…th allocating storage in sparse matrix.
  • Loading branch information
thhoens committed Mar 31, 2016
1 parent c02829c commit 5e481b8
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 129 deletions.
66 changes: 40 additions & 26 deletions Source/Math/CPUSparseMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ template <class ElemType>
CPUSparseMatrix<ElemType>::CPUSparseMatrix(const MatrixFormat format, const size_t numRows, const size_t numCols, const size_t size)
{
CheckInit(format);
RequireSize(numRows, numCols, size, true, false);
RequireSizeAndAllocate(numRows, numCols, size, true, false);
}

// copy constructor, deep copy
Expand Down Expand Up @@ -209,7 +209,7 @@ void CPUSparseMatrix<ElemType>::SetValue(const size_t row, const size_t col, con
let nz = NzCount();
if (GetSizeAllocated() < nz + 1) // automatic resize
{
Resize(m_numRows, m_numCols, nz + 100, true, true); // allocate 100 more elelemnts and keep existing values
Allocate(m_numRows, m_numCols, nz + 100, true, true); // allocate 100 more elelemnts and keep existing values
}

if (row < 0 || row >= m_numRows)
Expand Down Expand Up @@ -258,7 +258,7 @@ void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix<ElemType>& v)
SetFormat(v.GetFormat());
SetExternalBuffer(false);

RequireSize(v.GetNumRows(), v.GetNumCols(), v.NzSize());
RequireSizeAndAllocate(v.GetNumRows(), v.GetNumCols(), v.NzSize());
size_t nz = v.NzCount();

if (nz > 0)
Expand Down Expand Up @@ -429,7 +429,7 @@ void CPUSparseMatrix<ElemType>::SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYP
LogicError("Cannot modify since the buffer is managed externally.");

SetFormat(matrixFormatSparseCSC);
RequireSize(numRows, numCols, nz, true, false);
RequireSizeAndAllocate(numRows, numCols, nz, true, false);

memcpy(RowLocation(), h_Row, RowSize());
memcpy(ColLocation(), h_CSCCol, ColSize());
Expand All @@ -449,33 +449,31 @@ ElemType* CPUSparseMatrix<ElemType>::Data()
}

template <class ElemType>
void CPUSparseMatrix<ElemType>::RequireSize(const size_t numRows, const size_t numCols, size_t numNZElemToReserve, const bool growOnly, bool keepExistingValues)
void CPUSparseMatrix<ElemType>::RequireSizeAndAllocate(const size_t numRows, const size_t numCols, size_t numNZElemToReserve, const bool growOnly, bool keepExistingValues)
{
numNZElemToReserve = max(numNZElemToReserve, (size_t)1);
size_t newCompIndexSize = (numCols > numRows ? numCols : numRows) + 1;
bool reallocate = (GetSizeAllocated() < numNZElemToReserve || (GetSizeAllocated() > numNZElemToReserve && !growOnly) || GetCompIndexSize() < newCompIndexSize);
RequireSize(numRows, numCols);

if (reallocate || numRows != GetNumRows() || numCols != GetNumCols())
Resize(numRows, numCols, numNZElemToReserve, growOnly, keepExistingValues);
Allocate(numRows, numCols, numNZElemToReserve, growOnly, keepExistingValues);
}

template <class ElemType>
void CPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, size_t numNZElemToReserve, const bool growOnly, bool keepExistingValues)
void CPUSparseMatrix<ElemType>::RequireSize(const size_t numRows, const size_t numCols)
{
if (numRows != GetNumRows() || numCols != GetNumCols())
Resize(numRows, numCols);
}

if (GetNumRows() != numRows || GetNumCols() != numCols)
{
VerifyResizable(__func__);
keepExistingValues = false;
}
template <class ElemType>
void CPUSparseMatrix<ElemType>::Allocate(const size_t numRows, const size_t numCols, size_t numNZElemToReserve, const bool growOnly, bool keepExistingValues)
{

if (GetNumStorageRows() != numRows || GetNumStorageCols() != numCols)
LogicError("Allocate called with dimensions (%d, %d), but the matrix is of dimensions (%d, %d). Resize must be called first.", numRows, numCols, GetNumStorageRows(), GetNumStorageCols());

numNZElemToReserve = max(numNZElemToReserve, (size_t) 1);
size_t newCompIndexSize = (numCols > numRows ? numCols : numRows) + 1;
bool reallocate = (GetSizeAllocated() < numNZElemToReserve || (GetSizeAllocated() > numNZElemToReserve && !growOnly) || GetCompIndexSize() < newCompIndexSize);

m_numRows = numRows;
m_numCols = numCols;

if (reallocate)
{
if (GetFormat() == MatrixFormat::matrixFormatSparseCSC || GetFormat() == MatrixFormat::matrixFormatSparseCSR)
Expand All @@ -485,7 +483,7 @@ void CPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCol
auto* compIndex = new CPUSPARSE_INDEX_TYPE[newCompIndexSize]();

if (keepExistingValues && (NzCount() > numNZElemToReserve || GetCompIndexSize() > newCompIndexSize))
LogicError("Resize: To keep values m_nz should <= numNZElemToReserve and m_compIndexSize <= newCompIndexSize");
LogicError("Allocate: To keep values m_nz should <= numNZElemToReserve and m_compIndexSize <= newCompIndexSize");

memset(pArray, 0, sizeof(ElemType) * numNZElemToReserve);
memset(unCompIndex, 0, sizeof(CPUSPARSE_INDEX_TYPE) * numNZElemToReserve);
Expand Down Expand Up @@ -532,8 +530,27 @@ void CPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCol
SetSizeAllocated(numNZElemToReserve);
SetCompIndexSize(newCompIndexSize);
}
}


// Note: Resize will only allocate a new buffer if the dimensions change.
template <class ElemType>
void CPUSparseMatrix<ElemType>::Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve, const bool growOnly)
{
VerifyResizable(__func__);

size_t newCompIndexSize = (numCols > numRows ? numCols : numRows) + 1;
bool reallocate = (GetCompIndexSize() < newCompIndexSize);

m_numRows = numRows;
m_numCols = numCols;
SetNumStorageRows(numRows);
SetNumStorageCols(numCols);

if (reallocate)
Allocate(numRows, numCols, numNZElemToReserve, growOnly, false);
else
memset(GetCompIndex(), 0, sizeof(CPUSPARSE_INDEX_TYPE) * newCompIndexSize);
}

// Reset matrix to 0.
Expand Down Expand Up @@ -568,10 +585,7 @@ void CPUSparseMatrix<ElemType>::MultiplyAndWeightedAdd(ElemType alpha, const CPU
InvalidArgument("CPUSparseMatrix::MultiplyAndWeightedAdd: The inner dimensions of a and b must match.");
}

if (beta == 0)
c.RequireSize(m, n);
else
c.VerifySize(m, n); // Can't resize if beta != 0
c.RequireSize(m, n);

if (beta == 0)
{
Expand Down Expand Up @@ -674,7 +688,7 @@ void CPUSparseMatrix<ElemType>::MultiplyAndAdd(ElemType alpha, const CPUMatrix<E

// allocate enough memory
c.SetFormat(matrixFormatSparseBlockCol);
c.RequireSize(m, n, m * min(n, rhs.NzCount()), true, false);
c.RequireSizeAndAllocate(m, n, m * min(n, rhs.NzCount()), true, false);

map<size_t, size_t> w2Id;
for (size_t j = 0; j < rhs.GetNumCols(); j++)
Expand Down Expand Up @@ -1182,7 +1196,7 @@ MATH_API File& operator>>(File& stream, CPUSparseMatrix<ElemType>& us)
if (us.GetFormat() != matrixFormatSparseCSC && us.GetFormat() != matrixFormatSparseCSR)
NOT_IMPLEMENTED;

us.RequireSize(rownum, colnum, nz, true, false);
us.RequireSizeAndAllocate(rownum, colnum, nz, true, false);

if (nz > 0)
{
Expand Down
4 changes: 2 additions & 2 deletions Source/Math/CPUSparseMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ class MATH_API CPUSparseMatrix : public BaseMatrix<ElemType>
}

void RequireSizeAndAllocate(const size_t numRows, const size_t numCols, size_t numNZElemToReserve = 10000, const bool growOnly = true, bool keepExistingValues = false);
void RequireSize(const size_t numRows, const size_t numCols, size_t numNZElemToReserve = 10000, const bool growOnly = true, bool keepExistingValues = false);
void Resize(const size_t numRows, const size_t numCols, size_t numNZElemToReserve = 10000, const bool growOnly = true, bool keepExistingValues = false);
void RequireSize(const size_t numRows, const size_t numCols);
void Resize(const size_t numRows, const size_t numCols, size_t numNZElemToReserve = 10000, const bool growOnly = true);
void Allocate(const size_t numRows, const size_t numCols, size_t numNZElemToReserve = 10000, const bool growOnly = true, bool keepExistingValues = false);
void Reset();

Expand Down
Loading

0 comments on commit 5e481b8

Please sign in to comment.