Skip to content

Commit

Permalink
Fixing strides and dilation attributes in ONNX AveragePool, MaxPool, …
Browse files Browse the repository at this point in the history
…Conv, ConvTranpose import.
  • Loading branch information
Spandan Tiwari committed Jun 27, 2018
1 parent e612bc4 commit 805c4a6
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions Source/CNTKv2LibraryDll/proto/onnx/ONNXToCNTK.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2021,7 +2021,7 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
else if (onnxOpName == "AveragePool" || onnxOpName == "MaxPool")
{
NDShape poolingWindowShape = GetNamedAttributeAsShape(node, "kernel_shape", false);
NDShape strides = GetNamedAttributeAsShape(node, "strides", false);
NDShape strides = GetNamedAttributeAsShape(node, "strides", false, NDShape(std::vector<size_t>(poolingWindowShape.Rank(), 1u)));

bool ceilOutDim = false;
bool includePad = false;
Expand Down Expand Up @@ -2949,9 +2949,11 @@ FunctionPtr ONNXToCNTKHelper::CreateCNTKConvTransposeNode(const Node *node, cons
{
Variable inputOperand = inputs[0];
Variable convolutionMap = inputs[1];
size_t numSpatialDim = convolutionMap.Shape().Rank() - 1; // This is conv op dimension, i.e. 2 for 2D conv, 3 for 3D conv.

NDShape strides = GetNamedAttributeAsShape(node, "strides", false, NDShape(std::vector<size_t>(numSpatialDim, 1u)));
NDShape dilation = GetNamedAttributeAsShape(node, "dilations", false, NDShape(std::vector<size_t>(numSpatialDim, 1u)));

NDShape strides = GetNamedAttributeAsShape(node, "strides", false);
NDShape dilation = GetNamedAttributeAsShape(node, "dilations", false, {1});
std::vector<bool> sharing({true});
size_t reductionRank = 1;
size_t maxTempMemSizeInSamples = 0;
Expand Down Expand Up @@ -3005,18 +3007,20 @@ FunctionPtr ONNXToCNTKHelper::CreateCNTKConvTransposeNode(const Node *node, cons

FunctionPtr ONNXToCNTKHelper::CreateCNTKConvNode(const Node *node, const std::vector<Variable> &inputs)
{
NDShape strides = GetNamedAttributeAsShape(node, "strides", false);
NDShape dilation = GetNamedAttributeAsShape(node, "dilations", false, {1});
Variable convolutionMap = inputs[1];
size_t numSpatialDim = convolutionMap.Shape().Rank() - 1; // This is conv op dimension, i.e. 2 for 2D conv, 3 for 3D conv.

NDShape strides = GetNamedAttributeAsShape(node, "strides", false, NDShape(std::vector<size_t>(numSpatialDim, 1u)));
NDShape dilation = GetNamedAttributeAsShape(node, "dilations", false, NDShape(std::vector<size_t>(numSpatialDim, 1u)));
// TODO: avoid hardcoded values
std::vector<bool> sharing({true});
size_t reductionRank = 1;
size_t maxTempMemSizeInSamples = 0;
size_t groups = GetNamedAttributeAsInt64(node, "group", 1);

Variable convolutionMap = inputs[1];

std::vector<bool> cntkConvAutoPadding;
auto convOperand = GetNodeOperandWithPaddingResolved(/*output arg first*/ cntkConvAutoPadding, strides, node, inputs);

FunctionPtr cntkConvFunction = Convolution(
convolutionMap,
convOperand,
Expand Down

0 comments on commit 805c4a6

Please sign in to comment.