diff --git a/Math/Math/CPUSparseMatrix.cpp b/Math/Math/CPUSparseMatrix.cpp index 09af29220a8a..e342498c7ba5 100644 --- a/Math/Math/CPUSparseMatrix.cpp +++ b/Math/Math/CPUSparseMatrix.cpp @@ -145,6 +145,27 @@ namespace Microsoft { namespace MSR { namespace CNTK { Resize(numRows, numCols, size, true, false); } + //copy constructor, deep copy + template + CPUSparseMatrix::CPUSparseMatrix(const CPUSparseMatrix& deepCopyFrom) + { + ZeroInit(); + if (!deepCopyFrom.IsEmpty()) + SetValue(deepCopyFrom); + SetMatrixName(deepCopyFrom.m_matrixName); + } + + //assignment operator, deep copy + template + CPUSparseMatrix& CPUSparseMatrix::operator=(const CPUSparseMatrix& deepCopyFrom) + { + Clear(); + if (!deepCopyFrom.IsEmpty()) + SetValue(deepCopyFrom); + SetMatrixName(deepCopyFrom.m_matrixName); + return *this; + } + template CPUSparseMatrix::~CPUSparseMatrix() { @@ -232,16 +253,20 @@ namespace Microsoft { namespace MSR { namespace CNTK { //make sure call order in colume wise for CSC and row wise for CSR template - void CPUSparseMatrix::SetValue(const CPUSparseMatrix& v) + void CPUSparseMatrix::SetValue(const CPUSparseMatrix& v) { this->Reset(); - this->Resize(v.GetNumRows(), v.GetNumCols(), v.NzSize()); + m_format = v.GetFormat(); - memcpy(this->NzValues(), v.NzValues(), v.NzSize()); - memcpy(this->RowLocation(), v.RowLocation(), v.RowSize()); - memcpy(this->ColLocation(), v.ColLocation(), v.ColSize()); + this->Resize(v.GetNumRows(), v.GetNumCols(), v.NzSize()); + m_nz = v.NzCount(); - m_nz = v.NzCount(); + if (m_nz > 0) + { + memcpy(this->NzValues(), v.NzValues(), v.NzSize()); + memcpy(this->RowLocation(), v.RowLocation(), v.RowSize()); + memcpy(this->ColLocation(), v.ColLocation(), v.ColSize()); + } } template diff --git a/Math/Math/CPUSparseMatrix.h b/Math/Math/CPUSparseMatrix.h index e9d5baffb632..6f58fbb58c85 100644 --- a/Math/Math/CPUSparseMatrix.h +++ b/Math/Math/CPUSparseMatrix.h @@ -33,6 +33,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { public: CPUSparseMatrix(const MatrixFormat format); CPUSparseMatrix(const MatrixFormat format, const size_t numRows, const size_t numCols, const size_t size); + CPUSparseMatrix(const CPUSparseMatrix& deepCopyFrom); //copy constructor, deep copy + CPUSparseMatrix& operator=(const CPUSparseMatrix& deepCopyFrom); //assignment operator, deep copy ~CPUSparseMatrix(); @@ -41,7 +43,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { using B::GetNumCols; using B::GetNumRows; void SetValue(const size_t row, const size_t col, ElemType val); - void SetValue(const CPUSparseMatrix& /*val*/); + void SetValue(const CPUSparseMatrix& /*val*/); void ShiftBy(int /*numShift*/) { NOT_IMPLEMENTED; } diff --git a/Math/Math/Matrix.cpp b/Math/Math/Matrix.cpp index 3a8da6ad1de5..042ec79684c0 100644 --- a/Math/Math/Matrix.cpp +++ b/Math/Math/Matrix.cpp @@ -457,12 +457,12 @@ namespace Microsoft { namespace MSR { namespace CNTK { this, m_CPUMatrix = new CPUMatrix(static_cast&&>(*(moveFrom.m_CPUMatrix))), m_GPUMatrix = new GPUMatrix(static_cast&&>(*(moveFrom.m_GPUMatrix))), - NOT_IMPLEMENTED, + m_CPUSparseMatrix = new CPUSparseMatrix(static_cast&&>(*(moveFrom.m_CPUSparseMatrix))), m_GPUSparseMatrix = new GPUSparseMatrix(static_cast&&>(*(moveFrom.m_GPUSparseMatrix))) ); m_preferredDeviceId = moveFrom.m_preferredDeviceId; - } + } //move assignment operator, shallow copy template @@ -479,11 +479,12 @@ namespace Microsoft { namespace MSR { namespace CNTK { if (m_GPUMatrix != nullptr) m_GPUMatrix->operator=(static_cast&&>(*(moveFrom.m_GPUMatrix))); else m_GPUMatrix = new GPUMatrix(static_cast&&>(*(moveFrom.m_GPUMatrix))), - NOT_IMPLEMENTED, + if (m_CPUSparseMatrix != nullptr) m_CPUSparseMatrix->operator=(static_cast&&>(*(moveFrom.m_CPUSparseMatrix))); + else m_CPUSparseMatrix = new CPUSparseMatrix(static_cast&&>(*(moveFrom.m_CPUSparseMatrix))), if (m_GPUSparseMatrix != nullptr) m_GPUSparseMatrix->operator=(static_cast&&>(*(moveFrom.m_GPUSparseMatrix))); else m_GPUSparseMatrix = new GPUSparseMatrix(static_cast&&>(*(moveFrom.m_GPUSparseMatrix))) - ); + ); return *this; }