Skip to content

Commit

Permalink
updated PerformSVDecomposition() to use ValueAsMatrix()
Browse files Browse the repository at this point in the history
  • Loading branch information
frankseide committed Jan 21, 2016
1 parent ce0c087 commit c9530e8
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions Source/ComputationNetworkLib/ComputationNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,11 +487,8 @@ void ComputationNetwork::CollectInputAndLearnableParameters(const ComputationNod
list<ComputationNodeBasePtr> inputs;
for (const auto& node : nodes)
{
if (node->OperationName() == OperationNameOf(InputValue) /*L"InputValue"*/ ||
node->OperationName() == OperationNameOf(SparseInputValue) /*L"SparseInputValue"*/)
{
if (node->OperationName() == OperationNameOf(InputValue) || node->OperationName() == OperationNameOf(SparseInputValue))
inputs.push_back(node);
}
}
m_inputValues[rootNode] = inputs;

Expand All @@ -500,13 +497,10 @@ void ComputationNetwork::CollectInputAndLearnableParameters(const ComputationNod
for (auto nodeIter = nodes.begin(); nodeIter != nodes.end(); nodeIter++)
{
ComputationNodeBasePtr node = *nodeIter;
if ((node->OperationName() == OperationNameOf(LearnableParameter) && node->IsParameterUpdateRequired())
//|| (node->OperationName() == OperationNameOf(SparseLearnableParameter) && node->IsParameterUpdateRequired())
)
{
if (node->OperationName() == OperationNameOf(LearnableParameter) && node->IsParameterUpdateRequired())
learnableParameterNames.push_back(node->NodeName());
}
}

// sort names so that we get consistent order when load it from saved file
learnableParameterNames.sort();

Expand Down Expand Up @@ -944,10 +938,9 @@ void ComputationNetwork::PerformSVDecomposition(const map<wstring, float>& SVDCo
}

shared_ptr<ComputationNode<ElemType>> pNode = dynamic_pointer_cast<LearnableParameter<ElemType>>(m_nameToNodeMap[name]);
//========================================

// Step 1. do SVD decomposition
//========================================
Matrix<ElemType> A = pNode->Value();
Matrix<ElemType> A = pNode->ValueAsMatrix();

// it is a vector, no need to do it
if (A.GetNumCols() == 1 || A.GetNumRows() == 1)
Expand All @@ -966,7 +959,7 @@ void ComputationNetwork::PerformSVDecomposition(const map<wstring, float>& SVDCo
// VT \in R^{nXn}
// S \in R^{min(m,n),1}
// S is in descending order
//

ElemType totalenergy = 0.0f;
for (size_t i = 0; i < S.GetNumRows(); i++)
totalenergy += S(i, 0);
Expand Down Expand Up @@ -1019,22 +1012,19 @@ void ComputationNetwork::PerformSVDecomposition(const map<wstring, float>& SVDCo
redU.RowElementMultiplyWith(redS.Transpose());
redVT.ColumnElementMultiplyWith(redS);

//========================================

// Step 2. create two new Parameter nodes and one Times node
//========================================
wstring leftChildName = name + L"-U";
wstring rightChildName = name + L"-V";
shared_ptr<ComputationNode<ElemType>> pLeft = AddNodeToNetWithElemType(New<LearnableParameter<ElemType>>(m_deviceId, leftChildName, m, r));
shared_ptr<ComputationNode<ElemType>> pRight = AddNodeToNetWithElemType(New<LearnableParameter<ElemType>>(m_deviceId, rightChildName, r, n));

pLeft->Value() = redU;
pRight->Value() = redVT;
pLeft->ValueAsMatrix() = redU;
pRight->ValueAsMatrix() = redVT;

shared_ptr<ComputationNode<ElemType>> pTimes = AddNodeToNetAndAttachInputs(New<TimesNode<ElemType>>(m_deviceId, name + L"-SVD"), pLeft, pRight);

//========================================
// Step 3. remove old node
//========================================
ReplaceLeafNode(name, pTimes);
}
}
Expand Down

0 comments on commit c9530e8

Please sign in to comment.