Skip to content

Commit

Permalink
added RowElementTimes and ColumnElementTimes nodes.
Browse files Browse the repository at this point in the history
Revoke back ElementTimes node to do element-wise multiplication since the implementation there for column element-wise multiplication is incorrect.
  • Loading branch information
Dong Yu committed Aug 5, 2015
1 parent a0567f6 commit d62e5db
Show file tree
Hide file tree
Showing 55 changed files with 724 additions and 969 deletions.
605 changes: 0 additions & 605 deletions CheckInSuites/SLU/Expected.log

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
53 changes: 44 additions & 9 deletions MachineLearning/CNTK/ComputationNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -1268,14 +1268,22 @@ class ComputationNetwork
{
newNode = new TransposeTimesNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
}
else if (nodeType == StrideTimesNode<ElemType>::TypeName())
{
newNode = new StrideTimesNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
}
else if (nodeType == StrideTimesNode<ElemType>::TypeName())
{
newNode = new StrideTimesNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
}
else if (nodeType == ElementTimesNode<ElemType>::TypeName())
{
newNode = new ElementTimesNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
}
else if (nodeType == RowElementTimesNode<ElemType>::TypeName())
{
newNode = new RowElementTimesNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
}
else if (nodeType == ColumnElementTimesNode<ElemType>::TypeName())
{
newNode = new ColumnElementTimesNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
}
else if (nodeType == DiagTimesNode<ElemType>::TypeName())
{
newNode = new DiagTimesNode<ElemType>(fstream, modelVersion, m_deviceId, nodeName);
Expand Down Expand Up @@ -1606,14 +1614,22 @@ class ComputationNetwork
{
newNode = new TransposeTimesNode<ElemType>(m_deviceId, nodeName);
}
else if (nodeType == StrideTimesNode<ElemType>::TypeName())
{
newNode = new StrideTimesNode<ElemType>(m_deviceId, nodeName);
}
else if (nodeType == StrideTimesNode<ElemType>::TypeName())
{
newNode = new StrideTimesNode<ElemType>(m_deviceId, nodeName);
}
else if (nodeType == ElementTimesNode<ElemType>::TypeName())
{
newNode = new ElementTimesNode<ElemType>(m_deviceId, nodeName);
}
else if (nodeType == RowElementTimesNode<ElemType>::TypeName())
{
newNode = new RowElementTimesNode<ElemType>(m_deviceId, nodeName);
}
else if (nodeType == ColumnElementTimesNode<ElemType>::TypeName())
{
newNode = new ColumnElementTimesNode<ElemType>(m_deviceId, nodeName);
}
else if (nodeType == DiagTimesNode<ElemType>::TypeName())
{
newNode = new DiagTimesNode<ElemType>(m_deviceId, nodeName);
Expand Down Expand Up @@ -2110,7 +2126,26 @@ class ComputationNetwork
return newNode;
}

ComputationNodePtr StrideTimes(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName = L"")
ComputationNodePtr RowElementTimes(const ComputationNodePtr a,
const ComputationNodePtr b,
const std::wstring nodeName = L"")
{
ComputationNodePtr newNode(new RowElementTimesNode<ElemType>(m_deviceId, nodeName));
newNode->AttachInputs(a, b);
AddNodeToNet(newNode);
return newNode;
}

ComputationNodePtr ColumnElementTimes(const ComputationNodePtr a,
const ComputationNodePtr b,
const std::wstring nodeName = L"")
{
ComputationNodePtr newNode(new ColumnElementTimesNode<ElemType>(m_deviceId, nodeName));
newNode->AttachInputs(a, b);
AddNodeToNet(newNode);
return newNode;
}
ComputationNodePtr StrideTimes(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName = L"")
{
ComputationNodePtr newNode(new StrideTimesNode<ElemType>(m_deviceId, nodeName));
newNode->AttachInputs(a, b, c);
Expand Down
Loading

0 comments on commit d62e5db

Please sign in to comment.