[PyTorch Edge] Support default args with out arg, flag off (pytorch#63540)

Summary:
1. Allow consuming operators with default arguments and out arguments. The flag is off to keep the same behavior as bytecode v6; PR #63651 turns the flag on (see the sketch after this list).
2. Add two unit tests to cover this type of operator.
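
For context, a minimal sketch of the call pattern this change covers (mirroring the new unit tests below; the module name and the choice of operator are illustrative only): the defaulted arguments are omitted while the trailing out argument is supplied.

```cpp
// Illustrative sketch: rcond and hermitian fall back to their schema defaults
// while the trailing out argument is explicitly specified.
Module m("m");
m.define(R"(
  def forward(self, input):
    return torch.linalg_pinv(input, out=input)
)");
```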

Pull Request resolved: pytorch#63540

ghstack-source-id: 137211562

Test Plan:
```
caffe2/test/cpp/jit:jit - LiteInterpreterTest.DefaultArgsWithOutArg
caffe2/test/cpp/jit:jit - LiteInterpreterTest.DefaultArgsPinvWithOutArg
```

Reviewed By: raziel, iseeyuan, tugsbayasgalan

Differential Revision: D30414156

fbshipit-source-id: 0f3a219a22aee10ac53184cbd95940726c459d1f
cccclai authored and facebook-github-bot committed Sep 2, 2021
1 parent 0addd75 commit 8d5b950
Showing 6 changed files with 115 additions and 27 deletions.
2 changes: 1 addition & 1 deletion caffe2/serialize/versions.h
@@ -85,7 +85,7 @@ static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion,
// we should support this model_version. For example, we provide a wrapper to
// handle an updated operator.
constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
constexpr uint64_t kMaxSupportedBytecodeVersion = 0x6L;
constexpr uint64_t kMaxSupportedBytecodeVersion = 0x7L;

} // namespace serialize
} // namespace caffe2
62 changes: 62 additions & 0 deletions test/cpp/jit/test_lite_interpreter.cpp
@@ -1035,6 +1035,68 @@ TEST(LiteInterpreterTest, DefaultArgsPinvSpecifyDefault) {
testLiteModuleCompareResultTensors(m, inputs);
}

void testDefaultArgsPinvWithOutArg(int num_args) {
Module m("m");
if (num_args == 1) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, out=input)
)");
} else if (num_args == 2) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, 1e-5, out=input)
)");
} else if (num_args == 3) {
m.define(R"(
def forward(self, input):
return torch.linalg_pinv(input, 1e-5, True, out=input)
)");
}

const int N = 28;
auto input = torch::range(1, N * N, 1);
input[0] = 10000; // a more stable matrix
input = input.view({N, N});
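// Running forward writes the pinv result into `input` in place via out=input,
// so the original values should be overwritten.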
auto ref = m.run_method("forward", input);
TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
TORCH_CHECK(input.equal(ref.toTensor()));
}

TEST(LiteInterpreterTest, DefaultArgsPinvWithOutArg) {
// Test with different number of specified arguments + out arg.
// Arguments not specified take default value.
for (int num_args = 1; num_args <= 3; ++num_args) {
testDefaultArgsPinvWithOutArg(num_args);
}
}

TEST(LiteInterpreterTest, DefaultArgsWithOutArg) {
Module m("m");
m.define(R"(
def forward(self, x, h):
torch.add(x, h, out=x)
)");

std::vector<IValue> inputs;
auto input_x = 2 * torch::ones({});
auto input_h = torch::ones({});
auto ref = m.run_method("forward", input_x, input_h);

std::stringstream ss;

m._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
bc.run_method("forward", input_x, input_h);
AT_ASSERT(input_x.equal(4 * torch::ones({})));

auto ops = _get_model_ops_and_info(ss);
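// aten::add.out has 4 schema arguments (self, other, alpha=1, out), so the
// recorded num_schema_args should be 4.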
auto op = ops.find("aten::add.out");
TORCH_CHECK(
op != ops.end() && op->second.num_schema_args.has_value() &&
op->second.num_schema_args.value() == 4);
}

TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
Module a("A");
a.define(R"(
38 changes: 26 additions & 12 deletions torch/csrc/jit/mobile/function.cpp
@@ -99,21 +99,35 @@ bool Function::append_operator(
// from model. We can use it to handle backward compatibility.
if (num_specified_args &&
num_specified_args.value() < static_cast<int64_t>(args.size())) {
// Sanity check at load time, to save perf at runtime
for (size_t i = num_specified_args.value(); i < args.size(); ++i) {
auto default_val = args[i].default_value();
TORCH_CHECK(
default_val.has_value(),
"Error happened at preparing for default values for the argument. The ",
i,
"th arguement of operator",
opname,
" does not have a specified value or default value. ");
}
fn = [fn, num_specified_args, args](Stack& stack) {
for (size_t i = num_specified_args.value(); i < args.size(); ++i) {
std::vector<IValue> out_args;
// The following logic pops and temporarily stores all out arguments
// from the stack (which can be 0 or more, and always appended to the
// schema), in order to push the necessary default values. Finally, the
// out arguments are pushed back into the stack.
for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
out_args.push_back(stack.back());
stack.pop_back();
}
size_t start_index = num_specified_args.value() - out_args.size();
TORCH_CHECK(
start_index >= 0,
"The number of output arguments is: ",
out_args.size(),
", which is more than the number of specified arguments: ",
num_specified_args.value());
for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) {
TORCH_CHECK(
args[i].default_value().has_value(),
"Error happened at preparing for default values for the argument. The ",
i,
"th argument ",
args[i].name(),
" does not have a specified value or default value. ");

stack.push_back(args[i].default_value());
}
stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
fn(stack);
};
}
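
The wrapper above rearranges the runtime stack: it pops the trailing out arguments, appends the defaults the caller omitted, and then re-appends the out arguments. A minimal standalone sketch of that rearrangement, with plain ints standing in for IValues and all names invented for illustration:

```cpp
#include <cassert>
#include <vector>

// Standalone sketch (not PyTorch code): ints stand in for IValues.
// The caller pushed the specified arguments followed by the out arguments;
// defaults for the omitted middle arguments must be inserted between them.
std::vector<int> fill_defaults_before_out(
    std::vector<int> stack,           // e.g. {input, out}
    const std::vector<int>& defaults, // defaults for the omitted arguments
    size_t num_out_args) {            // trailing out args in the schema
  // 1. Pop the out arguments off the back and stash them.
  std::vector<int> out_args;
  for (size_t i = 0; i < num_out_args; ++i) {
    out_args.push_back(stack.back());
    stack.pop_back();
  }
  // 2. Push the default values for the arguments the caller omitted.
  for (int d : defaults) {
    stack.push_back(d);
  }
  // 3. Push the out arguments back in their original order.
  stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
  return stack;
}

int main() {
  // Caller specified {input=1} and {out=9}; defaults {2, 3} land in between.
  auto stack = fill_defaults_before_out({1, 9}, {2, 3}, /*num_out_args=*/1);
  assert((stack == std::vector<int>{1, 2, 3, 9}));
  return 0;
}
```

Popping into out_args reverses their order, which is why they are re-inserted with rbegin/rend, matching the insert call in the wrapper above.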
2 changes: 2 additions & 0 deletions torch/csrc/jit/runtime/interpreter.cpp
@@ -978,11 +978,13 @@ MobileCode::MobileCode(
const std::shared_ptr<Graph>& graph,
std::string function_name,
bool emit_default_input_instructions,
bool support_default_args_before_out,
size_t remaining_bailout_depth)
: Code(new interpreter::MobileCodeImpl(
graph,
std::move(function_name),
emit_default_input_instructions,
support_default_args_before_out,
remaining_bailout_depth)) {}

MobileCode::~MobileCode() = default;
1 change: 1 addition & 0 deletions torch/csrc/jit/runtime/interpreter.h
@@ -82,6 +82,7 @@ struct TORCH_API MobileCode : Code {
const std::shared_ptr<Graph>& graph,
std::string function_name,
bool emit_default_input_instructions = true,
bool support_default_args_before_out = false,
size_t remaining_bailout_depth = 0);
~MobileCode();
};
37 changes: 23 additions & 14 deletions torch/csrc/jit/runtime/interpreter/code_impl.h
@@ -721,9 +721,11 @@ struct MobileCodeImpl : CodeImpl {
const std::shared_ptr<Graph>& graph,
std::string function_name,
bool emit_default_input_instructions,
bool support_default_args_before_out,
size_t remaining_bailout_depth)
: CodeImpl(graph, function_name, remaining_bailout_depth, false),
emit_default_input_instructions_(emit_default_input_instructions) {
emit_default_input_instructions_(emit_default_input_instructions),
support_default_args_before_out_(support_default_args_before_out) {
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
run();
}
@@ -746,11 +748,12 @@ struct MobileCodeImpl : CodeImpl {
// skip if schema has vararg
if (!op_schema.is_vararg()) {
auto specifiedArgs = CalculateNecessaryArgs(
op_schema.arguments(), node->inputs(), false);
// preserving the old behavior
auto numInclude = specifiedArgs.first;
// TODO uncomment this
// auto numInclude = specifiedArgs.first + specifiedArgs.second;
op_schema.arguments(),
node->inputs(),
support_default_args_before_out_);

size_t numInclude = specifiedArgs.first +
(support_default_args_before_out_ ? specifiedArgs.second : 0);
auto unique_name = op_schema.overload_name() != ""
? op_schema.name() + "." + op_schema.overload_name()
: op_schema.name();
@@ -782,21 +785,27 @@
if (it != op_to_num_specified_args_.end()) {
num_include = it->second;
}
emitLoadInputs(node->inputs(), num_include);
// TODO: uncomment this
// auto num_out = op_to_num_out_args_.find(unique_op_name)->second;
// auto num_specified_before_out = num_include - num_out;
// emitLoadInputs(node->inputs(), 0, num_specified_before_out);
// emitLoadInputs(node->inputs(), node->inputs().size() - num_out,
// node->inputs().size());

if (support_default_args_before_out_) {
auto num_out = op_to_num_out_args_.find(unique_op_name)->second;
auto num_specified_before_out = num_include - num_out;
emitLoadInputs(node->inputs(), 0, num_specified_before_out);
emitLoadInputs(
node->inputs(),
node->inputs().size() - num_out,
node->inputs().size());
} else {
emitLoadInputs(node->inputs(), num_include);
}
insertInstruction(OP, operator_table_.size());
}
operator_table_.emplace_back(op.getOperation(node));
}
}

// To support forward compatibility for bytecode version bump from v5 to v6
bool emit_default_input_instructions_;
// To support forward compatibility for bytecode version bump from v6 to v7
bool support_default_args_before_out_;
};

} // namespace interpreter
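
With support_default_args_before_out_ enabled, emitOperator counts the out arguments into num_include and loads the node's inputs in two ranges, skipping the defaulted arguments in the middle; those are filled in at run time by the function.cpp wrapper shown earlier. A hedged standalone sketch of that index arithmetic, using the aten::add.out case from the new test (the function and struct names here are invented for illustration):

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Standalone sketch (not the real CalculateNecessaryArgs/emitLoadInputs):
// given how many non-out arguments the caller specified and how many trailing
// out arguments the schema has, compute the two input ranges that would be
// loaded when support_default_args_before_out is enabled.
struct LoadRanges {
  std::pair<size_t, size_t> before_out; // [0, num_specified_before_out)
  std::pair<size_t, size_t> out;        // [num_inputs - num_out, num_inputs)
};

LoadRanges plan_loads(size_t num_specified_non_out,
                      size_t num_out,
                      size_t num_inputs) {
  // Mirrors: numInclude = specifiedArgs.first + specifiedArgs.second
  size_t num_include = num_specified_non_out + num_out;
  size_t num_specified_before_out = num_include - num_out;
  return {{0, num_specified_before_out}, {num_inputs - num_out, num_inputs}};
}

int main() {
  // torch.add(x, h, out=x): the schema (self, other, alpha=1, out) has 4
  // inputs; the caller specified self and other plus the out argument, so
  // alpha (input index 2) is skipped and its default is filled at run time.
  LoadRanges r = plan_loads(/*num_specified_non_out=*/2,
                            /*num_out=*/1,
                            /*num_inputs=*/4);
  assert(r.before_out.first == 0 && r.before_out.second == 2);
  assert(r.out.first == 3 && r.out.second == 4);
  return 0;
}
```

Loading ranges [0, 2) and [3, 4) leaves input index 2 (alpha) out of the emitted bytecode, which is what lets the runtime substitute its default value.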
