Skip to content

Commit

Permalink
CPU Sparse copy constructor.
Browse files Browse the repository at this point in the history
  • Loading branch information
UnderdogGeek committed Jul 19, 2015
1 parent fa3fb05 commit cd34812
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
37 changes: 31 additions & 6 deletions Math/Math/CPUSparseMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ namespace Microsoft { namespace MSR { namespace CNTK {
Resize(numRows, numCols, size, true, false);
}

//copy constructor, deep copy
template<class ElemType>
CPUSparseMatrix<ElemType>::CPUSparseMatrix(const CPUSparseMatrix<ElemType>& deepCopyFrom)
{
ZeroInit();
if (!deepCopyFrom.IsEmpty())
SetValue(deepCopyFrom);
SetMatrixName(deepCopyFrom.m_matrixName);
}

//assignment operator, deep copy
template<class ElemType>
CPUSparseMatrix<ElemType>& CPUSparseMatrix<ElemType>::operator=(const CPUSparseMatrix<ElemType>& deepCopyFrom)
{
Clear();
if (!deepCopyFrom.IsEmpty())
SetValue(deepCopyFrom);
SetMatrixName(deepCopyFrom.m_matrixName);
return *this;
}

template<class ElemType>
CPUSparseMatrix<ElemType>::~CPUSparseMatrix()
{
Expand Down Expand Up @@ -232,16 +253,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {

//make sure call order in colume wise for CSC and row wise for CSR
template<class ElemType>
void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix& v)
void CPUSparseMatrix<ElemType>::SetValue(const CPUSparseMatrix<ElemType>& v)
{
this->Reset();
this->Resize(v.GetNumRows(), v.GetNumCols(), v.NzSize());
m_format = v.GetFormat();

memcpy(this->NzValues(), v.NzValues(), v.NzSize());
memcpy(this->RowLocation(), v.RowLocation(), v.RowSize());
memcpy(this->ColLocation(), v.ColLocation(), v.ColSize());
this->Resize(v.GetNumRows(), v.GetNumCols(), v.NzSize());
m_nz = v.NzCount();

m_nz = v.NzCount();
if (m_nz > 0)
{
memcpy(this->NzValues(), v.NzValues(), v.NzSize());
memcpy(this->RowLocation(), v.RowLocation(), v.RowSize());
memcpy(this->ColLocation(), v.ColLocation(), v.ColSize());
}
}

template<class ElemType>
Expand Down
4 changes: 3 additions & 1 deletion Math/Math/CPUSparseMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
public:
CPUSparseMatrix(const MatrixFormat format);
CPUSparseMatrix(const MatrixFormat format, const size_t numRows, const size_t numCols, const size_t size);
CPUSparseMatrix(const CPUSparseMatrix<ElemType>& deepCopyFrom); //copy constructor, deep copy
CPUSparseMatrix<ElemType>& operator=(const CPUSparseMatrix<ElemType>& deepCopyFrom); //assignment operator, deep copy


~CPUSparseMatrix();
Expand All @@ -41,7 +43,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
using B::GetNumCols; using B::GetNumRows;

void SetValue(const size_t row, const size_t col, ElemType val);
void SetValue(const CPUSparseMatrix& /*val*/);
void SetValue(const CPUSparseMatrix<ElemType>& /*val*/);

void ShiftBy(int /*numShift*/) { NOT_IMPLEMENTED; }

Expand Down
9 changes: 5 additions & 4 deletions Math/Math/Matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,12 +457,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
this,
m_CPUMatrix = new CPUMatrix<ElemType>(static_cast<CPUMatrix<ElemType>&&>(*(moveFrom.m_CPUMatrix))),
m_GPUMatrix = new GPUMatrix<ElemType>(static_cast<GPUMatrix<ElemType>&&>(*(moveFrom.m_GPUMatrix))),
NOT_IMPLEMENTED,
m_CPUSparseMatrix = new CPUSparseMatrix<ElemType>(static_cast<CPUSparseMatrix<ElemType>&&>(*(moveFrom.m_CPUSparseMatrix))),
m_GPUSparseMatrix = new GPUSparseMatrix<ElemType>(static_cast<GPUSparseMatrix<ElemType>&&>(*(moveFrom.m_GPUSparseMatrix)))
);

m_preferredDeviceId = moveFrom.m_preferredDeviceId;
}
}

//move assignment operator, shallow copy
template<class ElemType>
Expand All @@ -479,11 +479,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
if (m_GPUMatrix != nullptr) m_GPUMatrix->operator=(static_cast<GPUMatrix<ElemType>&&>(*(moveFrom.m_GPUMatrix)));
else m_GPUMatrix = new GPUMatrix<ElemType>(static_cast<GPUMatrix<ElemType>&&>(*(moveFrom.m_GPUMatrix))),

NOT_IMPLEMENTED,
if (m_CPUSparseMatrix != nullptr) m_CPUSparseMatrix->operator=(static_cast<CPUSparseMatrix<ElemType>&&>(*(moveFrom.m_CPUSparseMatrix)));
else m_CPUSparseMatrix = new CPUSparseMatrix<ElemType>(static_cast<CPUSparseMatrix<ElemType>&&>(*(moveFrom.m_CPUSparseMatrix))),

if (m_GPUSparseMatrix != nullptr) m_GPUSparseMatrix->operator=(static_cast<GPUSparseMatrix<ElemType>&&>(*(moveFrom.m_GPUSparseMatrix)));
else m_GPUSparseMatrix = new GPUSparseMatrix<ElemType>(static_cast<GPUSparseMatrix<ElemType>&&>(*(moveFrom.m_GPUSparseMatrix)))
);
);

return *this;
}
Expand Down

0 comments on commit cd34812

Please sign in to comment.