Skip to content

Commit

Permalink
Various fixes (pytorch#3895)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3895

Some fixes for models with EmbeddingBagByteRowwiseOffsets

Reviewed By: qizzzh

Differential Revision: D18940211

fbshipit-source-id: 0a1b9771602c6465e04d246c1a3b57a74366694e
  • Loading branch information
jackm321 authored and facebook-github-bot committed Dec 12, 2019
1 parent 4bfe34e commit 3e8854f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
3 changes: 3 additions & 0 deletions torch_glow/src/GlowFuser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

Expand Down Expand Up @@ -269,6 +270,8 @@ void glowCustomFuseImpl(std::shared_ptr<torch::jit::Graph> graph,
return fn(ptNode);
};

Inline(*graph);

// Prepare the graph by fusing known patterns for the model loader.
// TODO: this should be done only on Glow subgraphs to avoid modifying parts
// of the graph that Glow will not be running.
Expand Down
30 changes: 28 additions & 2 deletions torch_glow/src/PyTorchModelLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ PyTorchModelLoader::getSymbolsMapping() {
{{"aten::gelu"}, &PyTorchModelLoader::loadGelu, {}},
{{"aten::tanh", "aten::tanh_"}, &PyTorchModelLoader::loadTanh, {}},
{{"aten::t", "aten::t_"}, &PyTorchModelLoader::loadT, {}},
{{"aten::to"}, &PyTorchModelLoader::loadTo, {}},
{{"aten::permute"}, &PyTorchModelLoader::loadPermute, {}},
{{"aten::transpose", "aten::transpose_"},
&PyTorchModelLoader::loadTranspose,
Expand Down Expand Up @@ -3060,6 +3061,18 @@ Error PyTorchModelLoader::loadPermute(const torch::jit::Node *ptNode) {
return addValueMapping(outputs[0], output);
}

/// Load a PyTorch aten::to node.
/// NOTE(review): currently a pass-through — the dtype/device conversion
/// requested by aten::to is ignored and the input NodeValue is mapped
/// directly to the node's output. TODO: lower to a ConvertTo node.
/// \returns error on failure (wrong arity or unmapped input value).
Error PyTorchModelLoader::loadTo(const torch::jit::Node *ptNode) {
  const auto nodeInputs = ptNode->inputs();
  const auto nodeOutputs = ptNode->outputs();
  // aten::to carries 5 inputs (self, dtype, non_blocking, copy, ...) and
  // produces a single output tensor.
  RETURN_IF_ERR(checkInputAndOutputSizes(nodeInputs, 5, nodeOutputs, 1));

  // TODO: use ConvertTo
  glow::NodeValue glowInput;
  ASSIGN_VALUE_OR_RETURN_ERR(glowInput,
                             getGlowNodeValueForValue(nodeInputs[0]));

  // Map the output straight to the (unconverted) input value.
  return addValueMapping(nodeOutputs[0], glowInput);
}

Error PyTorchModelLoader::loadFlatten(const torch::jit::Node *ptNode) {
auto inputs = ptNode->inputs();
auto outputs = ptNode->outputs();
Expand Down Expand Up @@ -3284,6 +3297,19 @@ Error PyTorchModelLoader::loadEmbeddingBagByteRowwiseOffsets(
"EmbeddingBagByteRowwiseOffsets.ones",
glow::Type(weight.getElementType(), {indices.dims()[0]}), 1.0);

glow::Constant *weightConstant =
llvm::dyn_cast<glow::Constant>(weight.getNode());

RETURN_ERR_IF_NOT(weightConstant,
strFormat("Expected Weight to be a Constant but found: %s",
weight.getNode()->getKindName()));

TypeRef fusedTy = F_.getParent()->uniqueType(ElemKind::UInt8FusedQTy,
weight.dims(), 0.0, 0);

weightConstant->setType(Storage::OutputIdx, fusedTy);
weightConstant->setPayloadType(fusedTy);

bool scaleGradByFreq;
ASSIGN_VALUE_OR_RETURN_ERR(
scaleGradByFreq,
Expand All @@ -3308,8 +3334,8 @@ Error PyTorchModelLoader::loadEmbeddingBagByteRowwiseOffsets(
RETURN_ERR_IF_NOT(sparse == false, "Currently only support sparse='false'");

auto *EB = F_.createEmbeddingBagByteRowwiseOffsets(
"EmbeddingBagByteRowwiseOffsets", weight, perSampleWeights, indices,
offsets);
"EmbeddingBagByteRowwiseOffsets", weightConstant->getOutput(),
perSampleWeights, indices, offsets);

return addValueMapping(outputs[0], EB->getResult());
}
Expand Down
4 changes: 4 additions & 0 deletions torch_glow/src/PyTorchModelLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,10 @@ class PyTorchModelLoader {
/// Load a PyTorch aten::permute node.
/// \returns error on failure.
Error loadPermute(const torch::jit::Node *ptNode);

/// Load a PyTorch aten::to node.
/// \returns error on failure.
Error loadTo(const torch::jit::Node *ptNode);
};

} // namespace glow
Expand Down

0 comments on commit 3e8854f

Please sign in to comment.