Skip to content

Commit

Permalink
Addressed CR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
amitaga committed Mar 4, 2016
1 parent a78f198 commit 9078a9f
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
9 changes: 7 additions & 2 deletions Source/ComputationNetworkLib/ReshapingNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,18 +615,23 @@ class DiagonalNode : public ComputationNodeNonLooping<ElemType>, public NumInput

virtual void /*ComputationNodeNonLooping::*/ BackpropToNonLooping(size_t /*inputIndex*/) override
{
#if 1
NOT_IMPLEMENTED;
#else
// The Implementation below is currently broken
auto& inputGradientValues = Input(0)->GradientAsMatrix();
auto& gradientValues = GradientAsMatrix();
// BUGBUG: This should use the memshare mechanism.
// TODO: use tensor lib, then this will be easy, no memsharing needed
Matrix<ElemType> diag(gradientValues.GetNumRows(), gradientValues.GetNumCols(), gradientValues.GetDeviceId());
diag.SetValue(gradientValues);
Matrix<ElemType> diag = gradientValues.DeepClone();
// BUGBUG: Resize does not preserve data - should be a reinterpret operation
diag.Resize(gradientValues.GetNumCols(), 1);
inputGradientValues.SetValue(0);
// BUGBUG: Must *add* to gradient!
inputGradientValues.SetDiagonalValue(diag);
#endif
}

virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
Expand Down
4 changes: 2 additions & 2 deletions Source/Math/Matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,11 @@ Matrix<ElemType>::Matrix(const size_t numRows, const size_t numCols, ElemType* p
m_baseMatrix->SetOwnBuffer(false);
}

//copy constructor, deep copy
// copy constructor, deep copy
template <class ElemType>
Matrix<ElemType> Matrix<ElemType>::DeepClone() const
{
return Matrix<ElemType>(*this, this->GetDeviceId());
return Matrix<ElemType>(*this, GetDeviceId());
}

template <class ElemType>
Expand Down
7 changes: 6 additions & 1 deletion Source/Math/Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class MATH_API Matrix : public MatrixBase

Matrix<ElemType> DeepClone() const;

// Disallow deep copy construction and assignment
// Disallow deep copy construction and assignment to avoid
// inadvertent silent deep copying
Matrix(const Matrix<ElemType>& deepCopyFrom) = delete;
Matrix<ElemType>& operator=(const Matrix<ElemType>& deepCopyFrom) = delete;

Expand Down Expand Up @@ -203,6 +204,7 @@ class MATH_API Matrix : public MatrixBase
m_baseMatrix->VerifySize(rows, cols);
}

// TODO: Call this ShallowClone instead?
Matrix<ElemType> AsReference() const
{
return ColumnSlice(0, GetNumCols());
Expand Down Expand Up @@ -303,6 +305,9 @@ class MATH_API Matrix : public MatrixBase
Matrix<ElemType> operator^(ElemType alpha) const; // element-wise power
Matrix<ElemType>& AssignElementPowerOf(const Matrix<ElemType>& a, const ElemType power);

// TODO: There are several functions below that perform an in-place operation
// We should prepend the names of these functions with InPlace for clearly indicating
// the semantics for callers.
Matrix<ElemType>& ElementMultiplyWith(const Matrix<ElemType>& a);
Matrix<ElemType>& AssignElementProductOf(const Matrix<ElemType>& a, const Matrix<ElemType>& b);
Matrix<ElemType>& AddElementProductOf(const Matrix<ElemType>& a, const Matrix<ElemType>& b);
Expand Down
2 changes: 1 addition & 1 deletion Source/SGDLib/DataReaderHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
wstring nodeName = node->GetName();
shared_ptr<ComputationNode<ElemType>> pLearnableNode = node;
auto& funvalue = pLearnableNode->Value(); // gradient may not be allocated when this function is first called
const auto& funvalue = pLearnableNode->Value(); // gradient may not be allocated when this function is first called
size_t nrow = funvalue.GetNumRows();
size_t ncol = funvalue.GetNumCols();
if (m_cachedGradient.find(nodeName) == m_cachedGradient.end())
Expand Down

0 comments on commit 9078a9f

Please sign in to comment.