Add support for ChannelShuffle node.
vuzelac-cadence authored and jfix71 committed Apr 18, 2019
1 parent 9d8dae9 commit d392571
Showing 6 changed files with 190 additions and 127 deletions.
22 changes: 2 additions & 20 deletions lib/Graph/Graph.cpp

@@ -1000,26 +1000,8 @@ SliceNode *Function::createSlice(llvm::StringRef name, NodeValue input,

 Node *Function::createChannelShuffle(llvm::StringRef name, NodeValue input,
                                      size_t group, size_t kernel) {
-  auto inDims = input.dims();
-  assert(kernel < inDims.size());
-
-  ShapeVector dims(inDims.begin(), inDims.end());
-  auto D = dims[kernel];
-  assert(D % group == 0);
-
-  dims.erase(dims.begin() + kernel);
-  // Reshape {D1, ... D_k, ... D_n} -> {D1, ... group, D_k / group, ... D_n}
-  dims.insert(dims.begin() + kernel, D / group);
-  dims.insert(dims.begin() + kernel, group);
-  Node *R1 = createReshape(name.str() + ".reshape1", input, dims);
-
-  std::vector<unsigned_t> transpose(dims.size());
-  for (size_t i = 0; i < transpose.size(); i++)
-    transpose[i] = i;
-  std::swap(transpose[kernel], transpose[kernel + 1]);
-  Node *T = createTranspose(name.str() + ".transpose", R1, transpose);
-
-  return createReshape(name.str() + ".reshape2", T, inDims);
+  return addNode(
+      new ChannelShuffleNode(name, input.getType(), input, group, kernel));
 }

 ReshapeNode *Function::createSqueeze(llvm::StringRef name, NodeValue input,
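For context on the operator's semantics: ChannelShuffle views the kernel dimension of size D as a {group, D / group} pair and transposes that pair, interleaving channels across the groups. The following standalone sketch (plain C++, not Glow code; the NCHW layout, the fixed channel axis, and all names are assumptions for illustration) computes the same permutation directly:

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Channel shuffle on the C axis of an NCHW tensor: view C as
// {group, C / group}, transpose that pair, and flatten again.
std::vector<float> channelShuffle(const std::vector<float> &in, size_t N,
                                  size_t C, size_t H, size_t W, size_t group) {
  assert(C % group == 0);
  const size_t cPerG = C / group;
  std::vector<float> out(in.size());
  for (size_t n = 0; n < N; ++n)
    for (size_t c = 0; c < C; ++c) {
      // In the {group, cPerG} view, channel c sits at (c / cPerG, c % cPerG);
      // the transpose sends it to flat index (c % cPerG) * group + c / cPerG.
      const size_t newC = (c % cPerG) * group + c / cPerG;
      for (size_t hw = 0; hw < H * W; ++hw)
        out[(n * C + newC) * H * W + hw] = in[(n * C + c) * H * W + hw];
    }
  return out;
}

int main() {
  // A 1x8x1x1 tensor holding its own channel indices, shuffled with group 4.
  std::vector<float> in{0, 1, 2, 3, 4, 5, 6, 7};
  for (float v : channelShuffle(in, 1, 8, 1, 1, 4))
    std::cout << v << ' '; // prints: 0 2 4 6 1 3 5 7
  std::cout << '\n';
  return 0;
}

With group = 4 over eight channels this maps {0,1,2,3,4,5,6,7} to {0,2,4,6,1,3,5,7}, which is exactly the permutation that the Reshape->Transpose->Reshape expansion (now in lib/Optimizer/Lower.cpp, below) produces.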
8 changes: 8 additions & 0 deletions lib/Graph/Nodes.cpp

@@ -586,6 +586,14 @@ bool TransposeNode::verify() const {
   return isValid;
 }

+bool ChannelShuffleNode::verify() const {
+  bool isValid = expectCompareTrue("Channel shuffle into a different size.",
+                                   getResult().getType()->size(),
+                                   getInput().getType()->size(), this);
+  isValid &= checkTypeIgnoreShape(getResult(), getInput(), this);
+  return isValid;
+}
+
 bool SplatNode::verify() const { return true; }

 bool TraceEventNode::verify() const { return true; }
176 changes: 96 additions & 80 deletions lib/Optimizer/GraphOptimizer.cpp

@@ -186,92 +186,25 @@ static Node *simplifyNode(Node *node, Function *F) {
   return node;
 }

-/// Parameters that are used to define ChannelShuffle operators.
-struct ChannelShuffleParams {
-  size_t group;
-  size_t kernel;
-};
-
-/// Compute the original parameters to the ChannelShuffle operator (represented
-/// as ReshapeNode->TransposeNode->ReshapeNode) for which \p node is the leading
-/// ReshapeNode. \returns The original ChannelShuffle parameters if possible and
-/// empty Optional otherwise.
-static llvm::Optional<ChannelShuffleParams>
-getChannelShuffleParams(const ReshapeNode &node) {
-  auto resM = llvm::Optional<ChannelShuffleParams>();
-
-  llvm::ArrayRef<size_t> inputDims = node.getInput().dims();
-  llvm::ArrayRef<size_t> resultDims = node.getDims();
-
-  // Check that there is one more output dimension than input dimension.
-  if (resultDims.size() != inputDims.size() + 1) {
-    return resM;
-  }
-
-  // Find the first output dimension that doesn't match its corresponding input
-  // dimension.
-  ChannelShuffleParams params;
-  bool found = false;
-  for (size_t i = 0, e = resultDims.size(); i < e - 1; ++i) {
-    if (inputDims[i] != resultDims[i]) {
-      params.kernel = i;
-      params.group = resultDims[i];
-      found = true;
-      break;
-    }
-  }
-
-  // Double check the property that the mismatched output found dimension and
-  // its successor together evenly multiply to the input dimension they
-  // mismatched on.
-  if (found && resultDims[params.kernel] * resultDims[params.kernel + 1] ==
-                   inputDims[params.kernel]) {
-    resM = params;
-  }
-
-  return resM;
-}
-
-/// Sink Transpose below ChannelShuffle node sequence ending with \p
-/// postShuffleRN. For example (Transpose_1->Reshape_1->Transpose_2->Reshape_2)
-/// becomes (Reshape_1->Transpose_2->Reshape_2->Transpose_1). \returns true if
-/// tranpose was sunk below ChannelShuffle node sequence and false otherwise.
+/// Sink Transpose below ChannelShuffle node.
 static bool sinkTranposeBelowChannelShuffle(Function *F,
-                                            ReshapeNode *postShuffleRN) {
-  auto *shuffleTR = dyn_cast<TransposeNode>(postShuffleRN->getInput());
-  if (!shuffleTR) {
+                                            ChannelShuffleNode *CS) {
+  auto *TR = dyn_cast<TransposeNode>(CS->getInput());
+  if (!TR) {
     return false;
   }

-  auto *preShuffleRN = dyn_cast<ReshapeNode>(shuffleTR->getInput());
-  if (!preShuffleRN) {
-    return false;
-  }
-
-  auto *sinkingTR = dyn_cast<TransposeNode>(preShuffleRN->getInput());
-  if (!sinkingTR) {
-    return false;
-  }
-
-  // Compute the original parameters to ChannelShuffle.
-  auto paramsM = getChannelShuffleParams(*preShuffleRN);
-
-  if (!paramsM.hasValue()) {
-    return false;
-  }
-
-  // Create a new ChannelShuffle with kernel parameter tranposed by the
-  // sinkingTR's shuffle because that Transpose will now be moved below this
+  // Create a new ChannelShuffle with kernel parameter transposed by the
+  // sinking TR's shuffle because that Transpose will now be moved below this
   // ChannelShuffle operator.
-  auto *newChannelShuffle = F->createChannelShuffle(
-      "channel_shuffle", sinkingTR->getInput(), paramsM->group,
-      sinkingTR->getShuffle()[paramsM->kernel]);
+  auto *newCS =
+      F->createChannelShuffle(CS->getName(), TR->getInput(), CS->getGroup(),
+                              TR->getShuffle()[CS->getKernel()]);

   // Create a copy of sinkingTR and insert after newChannelShuffle.
-  auto *newSinkingTR = F->createTranspose(
-      sinkingTR->getName(), newChannelShuffle, sinkingTR->getShuffle());
+  auto *newTR = F->createTranspose(TR->getName(), newCS, TR->getShuffle());

-  postShuffleRN->getResult().replaceAllUsesOfWith(newSinkingTR);
+  CS->getResult().replaceAllUsesOfWith(newTR);

   return true;
 }
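The rewrite is sound because a Transpose only renames axes: shuffling axis k of the transposed tensor moves the same data as first shuffling axis shuffle[k] of the original input and then transposing, which is why the new node's kernel is TR->getShuffle()[CS->getKernel()]. A minimal standalone check of that identity for the rank-2 case (plain C++, not Glow code; the 2x8 shape, group = 4, and the {1, 0} shuffle are assumptions for illustration):

#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<int>>;

// Plain 2-D transpose, i.e. the shuffle {1, 0}.
Mat transpose(const Mat &x) {
  Mat y(x[0].size(), std::vector<int>(x.size()));
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t j = 0; j < x[0].size(); ++j)
      y[j][i] = x[i][j];
  return y;
}

// Channel shuffle on axis 0: view the row axis as {group, rows / group}
// and transpose that pair.
Mat shuffleRows(const Mat &x, size_t group) {
  const size_t perG = x.size() / group;
  Mat y(x.size());
  for (size_t r = 0; r < x.size(); ++r)
    y[(r % perG) * group + r / perG] = x[r];
  return y;
}

// Channel shuffle on axis 1, written out directly (not via transpose) so
// the equality check below is not circular.
Mat shuffleCols(const Mat &x, size_t group) {
  const size_t perG = x[0].size() / group;
  Mat y(x.size(), std::vector<int>(x[0].size()));
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t j = 0; j < x[0].size(); ++j)
      y[i][(j % perG) * group + j / perG] = x[i][j];
  return y;
}

int main() {
  Mat x(2, std::vector<int>(8));
  for (size_t i = 0; i < 2; ++i)
    for (size_t j = 0; j < 8; ++j)
      x[i][j] = static_cast<int>(i * 8 + j);
  // Shuffling axis 0 of the transposed tensor (kernel = 0, group = 4)...
  Mat lhs = shuffleRows(transpose(x), 4);
  // ...equals shuffling axis shuffle[0] = 1 first, then transposing.
  Mat rhs = transpose(shuffleCols(x, 4));
  assert(lhs == rhs);
  return 0;
}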
@@ -442,9 +375,9 @@ static bool sinkCode(Function *F) {
       continue;
     }

-    if (auto *RN = dyn_cast<ReshapeNode>(node)) {
+    if (auto *CS = dyn_cast<ChannelShuffleNode>(node)) {
       // Sink Transpose below ChannelShuffle.
-      if (sinkTranposeBelowChannelShuffle(F, RN)) {
+      if (sinkTranposeBelowChannelShuffle(F, CS)) {
         changed = true;
         continue;
       }
@@ -2512,6 +2445,86 @@ static void foldLeakyRelu(Function *F) {
   }
 }

+/// Parameters that are used to define ChannelShuffle operators.
+struct ChannelShuffleParams {
+  size_t group;
+  size_t kernel;
+};
+
+/// Compute the original parameters to the ChannelShuffle operator (represented
+/// as ReshapeNode->TransposeNode->ReshapeNode) for which \p node is the leading
+/// ReshapeNode. \returns The original ChannelShuffle parameters if possible and
+/// empty Optional otherwise.
+static llvm::Optional<ChannelShuffleParams>
+getChannelShuffleParams(const ReshapeNode &node) {
+  auto resM = llvm::Optional<ChannelShuffleParams>();
+
+  llvm::ArrayRef<size_t> inputDims = node.getInput().dims();
+  llvm::ArrayRef<size_t> resultDims = node.getDims();
+
+  // Check that there is one more output dimension than input dimension.
+  if (resultDims.size() != inputDims.size() + 1) {
+    return resM;
+  }
+
+  // Find the first output dimension that doesn't match its corresponding input
+  // dimension.
+  ChannelShuffleParams params;
+  bool found = false;
+  for (size_t i = 0, e = resultDims.size(); i < e - 1; ++i) {
+    if (inputDims[i] != resultDims[i]) {
+      params.kernel = i;
+      params.group = resultDims[i];
+      found = true;
+      break;
+    }
+  }
+
+  // Double check the property that the mismatched output dimension and
+  // its successor together evenly multiply to the input dimension they
+  // mismatched on.
+  if (found && resultDims[params.kernel] * resultDims[params.kernel + 1] ==
+                   inputDims[params.kernel]) {
+    resM = params;
+  }
+
+  return resM;
+}
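A worked example of the recovery: for a leading Reshape of {1, 8, 4, 4} into {1, 4, 2, 4, 4}, the first mismatch is at index 1, so kernel = 1 and group = 4, and the final check 4 * 2 == 8 confirms the pattern. A compact standalone replica of that logic (plain C++, not Glow code; the shapes are assumptions for illustration):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::vector<size_t> in{1, 8, 4, 4};
  const std::vector<size_t> out{1, 4, 2, 4, 4};
  // There must be exactly one more output dimension than input dimension.
  assert(out.size() == in.size() + 1);

  // Find the first output dimension that differs from its input counterpart.
  size_t kernel = 0;
  while (kernel < out.size() - 1 && in[kernel] == out[kernel])
    ++kernel;
  const size_t group = out[kernel];

  // The split pair must multiply back to the original dimension.
  assert(group * out[kernel + 1] == in[kernel]);
  assert(kernel == 1 && group == 4);
  return 0;
}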

+// Fold Reshape->Transpose->Reshape into ChannelShuffle when applicable.
+static void foldChannelShuffle(Function *F) {
+  auto &nodes = F->getNodes();
+  for (auto &node : nodes) {
+    auto *RN2 = dyn_cast<ReshapeNode>(&node);
+    if (!RN2) {
+      continue;
+    }
+
+    auto *TR = dyn_cast<TransposeNode>(RN2->getInput());
+    if (!TR) {
+      continue;
+    }
+
+    auto *RN1 = dyn_cast<ReshapeNode>(TR->getInput());
+    if (!RN1) {
+      continue;
+    }
+
+    // Compute the original parameters to ChannelShuffle.
+    auto paramsM = getChannelShuffleParams(*RN1);
+    if (!paramsM.hasValue()) {
+      continue;
+    }
+
+    // Create a new ChannelShuffle from the recovered group and kernel
+    // parameters.
+    auto *newCS = F->createChannelShuffle("channel_shuffle", RN1->getInput(),
+                                          paramsM->group, paramsM->kernel);
+    RN2->getResult().replaceAllUsesOfWith(newCS);
+  }
+}
+
 void glow::fold(Function *F, const CompilationOptions &opts) {
   (void)opts;
   // Get Reshape nodes merged into constants to simplify folding.
@@ -2520,6 +2533,9 @@ void glow::fold(Function *F, const CompilationOptions &opts) {
   // Fold sub-graphs corresponding to leakyRelu.
   foldLeakyRelu(F);

+  // Fold Reshape->Transpose->Reshape into ChannelShuffle when applicable.
+  foldChannelShuffle(F);
+
   // Perform Dead Code Elimination.
   DCE(F);
 }
33 changes: 33 additions & 0 deletions lib/Optimizer/Lower.cpp

@@ -643,6 +643,37 @@ static void lowerTileNode(Function *F, LoweredInfoMap *loweredMap,
   replaceAllUsesOfWith(loweredMap, TN.getResult(), IN);
 }

+static void lowerChannelShuffleNode(Function *F, LoweredInfoMap *loweredMap,
+                                    const ChannelShuffleNode &CSN) {
+  auto input = CSN.getInput();
+  auto group = CSN.getGroup();
+  auto kernel = CSN.getKernel();
+
+  auto inDims = input.dims();
+  assert(kernel < inDims.size());
+
+  ShapeVector dims(inDims.begin(), inDims.end());
+  auto D = dims[kernel];
+  assert(D % group == 0);
+
+  dims.erase(dims.begin() + kernel);
+  // Reshape {D1, ... D_k, ... D_n} -> {D1, ... group, D_k / group, ... D_n}
+  dims.insert(dims.begin() + kernel, D / group);
+  dims.insert(dims.begin() + kernel, group);
+  auto *R1 = F->createReshape(CSN.getName().str() + ".reshape1", input, dims);
+
+  std::vector<unsigned_t> transpose(dims.size());
+  for (size_t i = 0; i < transpose.size(); i++) {
+    transpose[i] = i;
+  }
+  std::swap(transpose[kernel], transpose[kernel + 1]);
+  auto *T =
+      F->createTranspose(CSN.getName().str() + ".transpose", R1, transpose);
+
+  auto *R2 = F->createReshape(CSN.getName().str() + ".reshape2", T, inDims);
+  replaceAllUsesOfWith(loweredMap, CSN.getResult(), R2);
+}
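Tracing the lowering on a concrete case: for an assumed input of {1, 8, 2, 2} with group = 4 and kernel = 1, reshape1 produces {1, 4, 2, 2, 2}, the transpose shuffle is {0, 2, 1, 3, 4}, and reshape2 restores {1, 8, 2, 2}. A standalone sketch of just the shape bookkeeping (plain C++, not Glow code; the shapes are assumptions for illustration):

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
  std::vector<size_t> dims{1, 8, 2, 2};
  const size_t group = 4, kernel = 1;
  const size_t D = dims[kernel];
  assert(D % group == 0);

  // Reshape 1: split dims[kernel] into {group, D / group}.
  dims.erase(dims.begin() + kernel);
  dims.insert(dims.begin() + kernel, D / group);
  dims.insert(dims.begin() + kernel, group);
  assert((dims == std::vector<size_t>{1, 4, 2, 2, 2}));

  // Transpose: identity shuffle with kernel and kernel + 1 swapped.
  std::vector<size_t> shuffle(dims.size());
  std::iota(shuffle.begin(), shuffle.end(), 0);
  std::swap(shuffle[kernel], shuffle[kernel + 1]);
  assert((shuffle == std::vector<size_t>{0, 2, 1, 3, 4}));

  // Reshape 2 restores {1, 8, 2, 2}; only the element order has changed.
  return 0;
}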

 static void lowerBatchReduceMeanNode(Function *F, LoweredInfoMap *loweredMap,
                                      const BatchedReduceMeanNode &BRM) {
   auto input = BRM.getBatch();
@@ -731,6 +762,8 @@ static void lowerNode(Function *F, Node *node, LoweredInfoMap *loweredMap) {
     lowerGroupConvolutionNode(F, loweredMap, *CN);
   } else if (auto *TN = dyn_cast<TileNode>(node)) {
     lowerTileNode(F, loweredMap, *TN);
+  } else if (auto *CSN = dyn_cast<ChannelShuffleNode>(node)) {
+    lowerChannelShuffleNode(F, loweredMap, *CSN);
   }
 }