Require FC node to have 2D input
jfix71 committed Mar 29, 2019
1 parent 73bec4b commit fd7d9a6
Showing 5 changed files with 50 additions and 42 deletions.
include/glow/Graph/Graph.h (22 changes: 15 additions & 7 deletions)

@@ -313,15 +313,20 @@ class Function final : public Named {
                            unsigned_t kernel, unsigned_t stride,
                            unsigned_t pad);

+  /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
+  /// \p W, bias \p B. If \p input is not 2 dimensional then it is flattened
+  /// along \p axis. Note, output type and outputDepth are inferred based on
+  /// the input types.
   FullyConnectedNode *createFullyConnected(llvm::StringRef name,
                                            NodeValue input, Storage *W,
-                                           Storage *B);
+                                           Storage *B, unsigned_t axis = 1);

-  /// Create a fully connected node with the specified output type.
-  /// Note, outputDepth is infered based on the output type.
+  /// Creates and \returns a FullyConnectedNode with \p name, \p input, weights
+  /// \p W, bias \p B, and \p outTy. If \p input is not 2 dimensional then it is
+  /// flattened along \p axis. Note, outputDepth is inferred based on \p outTy.
   FullyConnectedNode *createFullyConnected(llvm::StringRef name,
                                            NodeValue input, Node *W, Node *B,
-                                           TypeRef outTy);
+                                           TypeRef outTy, unsigned_t axis = 1);

   /// Create a row-wise quantized fully connected node. This node is only used
   /// in quantization. Args \p input and \p B are quantized in regular way, \p W
@@ -926,11 +931,14 @@ class Function final : public Named {
                               unsigned_t stride, unsigned_t pad,
                               unsigned_t group);

-  /// Create a fully connected node with the given \p name, \p input and \p
-  /// output depth. Trainable weight and bias variables are created implicitly.
+  /// Creates and \returns a FullyConnectedNode with \p name, \p input, and
+  /// \p outDepth. If \p input is not 2 dimensional then it is flattened along
+  /// \p axis. Note, output type is inferred based on the input types.
+  /// Trainable weight and bias variables are created implicitly.
   FullyConnectedNode *createFullyConnected(PlaceholderBindings &bindings,
                                            llvm::StringRef name,
-                                           NodeValue input, size_t outDepth);
+                                           NodeValue input, size_t outDepth,
+                                           unsigned_t axis = 1);

   /// Create an unrolled single-layer Simple RNN cell with \p hiddenSize
   /// dimensionality of the hidden state and \p outputSize dimensionality of the
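The effect of the new defaulted axis argument is that callers with a non-2D input no longer flatten it themselves. A minimal usage sketch (the module, function, bindings, and shapes here are hypothetical, not code from this commit):

    // A 4D input placeholder of shape {8, 3, 28, 28}. With the default
    // axis = 1, createFullyConnected flattens it to {8, 3 * 28 * 28} =
    // {8, 2352} via an implicit Flatten node before building the FC.
    Placeholder *input = mod.createPlaceholder(
        ElemKind::FloatTy, {8, 3, 28, 28}, "input", /* isTrainable */ false);
    FullyConnectedNode *fc =
        F->createFullyConnected(bindings, "fc", input, /* outDepth */ 10);
    // fc->getResult().dims() is {8, 10}.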
lib/Graph/Graph.cpp (38 changes: 23 additions & 15 deletions)

@@ -643,20 +643,27 @@ AvgPoolNode *Function::createAvgPool(llvm::StringRef name, NodeValue input,

 FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                    NodeValue input, Storage *W,
-                                                   Storage *B) {
+                                                   Storage *B,
+                                                   unsigned_t axis) {
   TypeRef T = input.getType();
   TypeRef OT = getParent()->uniqueTypeWithNewShape(
       T, {input.dims()[0], B->getType()->dims()[0]});

-  return addNode(new FullyConnectedNode(name, OT, input, W, B));
+  return createFullyConnected(name, input, W, B, OT, axis);
 }

 FullyConnectedNode *Function::createFullyConnected(llvm::StringRef name,
                                                    NodeValue input, Node *W,
-                                                   Node *B, TypeRef outTy) {
+                                                   Node *B, TypeRef outTy,
+                                                   unsigned_t axis) {
   assert(outTy->dims().size() == 2 && "Invalid number of dimensions");
   assert(outTy->dims()[0] == input.dims()[0] && "Invalid dimensions");

+  // FC always uses 2D input; flatten if necessary.
+  if (input.dims().size() != 2) {
+    input = createFlatten(name.str() + ".reshape2D", input, axis);
+  }
+
   TypeRef OT = getParent()->uniqueType(*outTy);
   return addNode(new FullyConnectedNode(name, OT, input, W, B));
 }
@@ -2077,22 +2084,23 @@ ConvertToNode *Function::createConvertTo(llvm::StringRef name, NodeValue input,
 FullyConnectedNode *
 Function::createFullyConnected(PlaceholderBindings &bindings,
                                llvm::StringRef name, NodeValue input,
-                               size_t outDepth) {
-  TypeRef T = input.getType();
-  auto idim = flattenCdr(input.dims());
-  size_t fanIn = idim.second;
+                               size_t outDepth, unsigned_t axis) {
+  const ElemKind k = input.getType()->getElementType();

-  auto *W = getParent()->createPlaceholder(
-      T->getElementType(), {idim.second, outDepth}, "weights", true);
-  auto *B = getParent()->createPlaceholder(T->getElementType(), {outDepth},
-                                           "bias", true);
+  // FC always uses 2D input; flatten if necessary.
+  if (input.dims().size() != 2) {
+    input = createFlatten(name.str() + ".reshape2D", input, axis);
+  }
+  auto *W = getParent()->createPlaceholder(k, {input.dims()[1], outDepth},
+                                           "weights", true);
+  auto *B = getParent()->createPlaceholder(k, {outDepth}, "bias", true);

-  bindings.allocate(W)->init(Tensor::InitKind::Xavier, fanIn, getPRNG());
+  bindings.allocate(W)->init(Tensor::InitKind::Xavier, input.dims()[1],
+                             getPRNG());
   bindings.allocate(B)->init(Tensor::InitKind::Broadcast, .1, getPRNG());

-  auto OT =
-      getParent()->uniqueType(T->getElementType(), {idim.first, outDepth});
-  return addNode(new FullyConnectedNode(name, OT, input, W, B));
+  auto OT = getParent()->uniqueType(k, {input.dims()[0], outDepth});
+  return createFullyConnected(name, input, W, B, OT, axis);
 }

 Node *Function::createDotProduct(llvm::StringRef name, NodeValue X,
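The flattening here delegates to createFlatten, which collapses an N-D shape to 2D around the axis: dimensions before the axis multiply into the first output dimension and the remaining ones into the second. A standalone sketch of that shape arithmetic (an illustrative helper, not part of the Glow API):

    #include <cstddef>
    #include <vector>

    // Sketch of the reshape done by a flatten along `axis`:
    // {d0, ..., d(n-1)} -> {d0 * ... * d(axis-1), d(axis) * ... * d(n-1)}.
    // E.g. flattenDims({8, 3, 28, 28}, 1) == {8, 2352}.
    std::vector<std::size_t> flattenDims(const std::vector<std::size_t> &dims,
                                         std::size_t axis) {
      std::size_t rows = 1, cols = 1;
      for (std::size_t i = 0; i < dims.size(); ++i) {
        (i < axis ? rows : cols) *= dims[i];
      }
      return {rows, cols};
    }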
lib/Graph/Nodes.cpp (12 changes: 7 additions & 5 deletions)

@@ -219,12 +219,14 @@ static bool verifyConvolution3D(NodeValue src, NodeValue dest, NodeValue filter,
 static bool verifyFullyConnected(NodeValue src, NodeValue weights,
                                  NodeValue bias, NodeValue dest) {
   const Node *parent = dest.getNode();
-  bool isValid = expectCompareTrue("Mismatch on expected source dimensions",
-                                   src.dims()[0], dest.dims()[0], parent);
+  bool isValid = expectCompareTrue("FC input must be 2D", size_t(2),
+                                   src.dims().size(), parent);
+  isValid &= expectCompareTrue("FC weights must be 2D", size_t(2),
+                               weights.dims().size(), parent);
   isValid &= expectCompareTrue("Mismatch on expected source dimensions",
-                               flattenCdr(src.dims()).second, weights.dims()[0],
-                               parent);
-
+                               src.dims()[0], dest.dims()[0], parent);
+  isValid &= expectCompareTrue("Mismatch on expected source dimensions",
+                               src.dims()[1], weights.dims()[0], parent);
   isValid &= expectCompareTrue("Inconsistent bias/dest sizes", bias.dims()[0],
                                weights.dims()[1], parent);
   isValid &= expectCompareTrue("Inconsistent weights/dest sizes",
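Concretely, the verifier now enforces the standard 2D FC shape contract rather than flattening on the fly with flattenCdr. Illustrative shapes that satisfy all of the checks above (values hypothetical):

    // src     : {8, 2352}   batch x inFeatures, must be 2D
    // weights : {2352, 10}  inFeatures x outDepth, must be 2D
    // bias    : {10}        outDepth
    // dest    : {8, 10}     batch x outDepth
    // A 4D src such as {8, 3, 28, 28} now fails "FC input must be 2D"
    // instead of being implicitly flattened during verification.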
lib/Importer/Caffe2ModelLoader.cpp (9 changes: 2 additions & 7 deletions)

@@ -568,11 +568,6 @@ llvm::Error Caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) {
       ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis"]));
     }

-    // If number of input dims is greater then 2 flatten on axis.
-    if (in.getType()->dims().size() > 2) {
-      in = G_.createFlatten("fc.in", in, axis);
-    }
-
     // Load weights.
     Tensor *w;
     ASSIGN_VALUE_OR_RETURN_ERR(w, getTensorByName(op.input(1)));
@@ -623,9 +618,9 @@ llvm::Error Caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) {
       auto outTy = G_.getParent()->uniqueType(
           ElemKind::Int8QTy, {in.getType()->dims()[0], B->getType()->dims()[0]},
           yScale, yZeroPoint - OFFSETSHIFT);
-      node = G_.createFullyConnected(opName, in, W, B, outTy);
+      node = G_.createFullyConnected(opName, in, W, B, outTy, axis);
     } else {
-      node = G_.createFullyConnected(opName, in, W, B);
+      node = G_.createFullyConnected(opName, in, W, B, axis);
     }

     // If number of original input dims is greater than 2, expand the output
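Caffe2's FC semantics are preserved: the operator's axis attribute (default 1) now flows through to createFullyConnected, which performs the same flatten the loader previously did inline. A worked shape example (values hypothetical):

    // Caffe2 FC op with a non-default axis:
    //   in   : {2, 3, 4, 5}, axis = 2
    //   -> Flatten inserted by createFullyConnected:
    //      {2 * 3, 4 * 5} = {6, 20}
    //   -> FC output: {6, outDepth}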
lib/Optimizer/Lower.cpp (11 changes: 3 additions & 8 deletions)

@@ -121,15 +121,10 @@ static void lowerRegressionGradNode(Function *F, LoweredInfoMap *loweredMap,

 static void lowerFullyConnectedNode(Function *F, LoweredInfoMap *loweredMap,
                                     const FullyConnectedNode &FC) {
-  auto *X = F->createFlatten("fc.1X", FC.getInput(), 1);
-
   auto W = FC.getWeights();
-  TypeRef outTy = F->getParent()->uniqueTypeWithNewShape(
-      FC.getResult().getType(), {X->getResult().dims()[0], W.dims()[1]});
-  auto *mul = F->createMatMul("fc.dot", outTy, X, W);
-
-  auto *add = F->createBatchedAdd("fc.add.bias", FC.getResult().getType(), mul,
-                                  FC.getBias());
+  TypeRef OT = FC.getResult().getType();
+  auto *mul = F->createMatMul("fc.dot", OT, FC.getInput(), W);
+  auto *add = F->createBatchedAdd("fc.add.bias", OT, mul, FC.getBias());
   replaceAllUsesOfWith(loweredMap, FC.getResult(), add);

   if (FC.hasPredicate()) {
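With the FC input guaranteed to be 2D, lowering no longer inserts its own Flatten; it is a direct MatMul plus BatchedAdd. A sketch of the node pattern produced (node names follow the strings in the code above; shapes illustrative):

    // FC(input {8, 2352}, weights {2352, 10}, bias {10}):
    //   fc.dot      = MatMul(input, weights)   -> {8, 10}
    //   fc.add.bias = BatchedAdd(fc.dot, bias) -> {8, 10}
    // Any needed N-D -> 2D reshape now happens when the FC is created,
    // not when it is lowered.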
