Skip to content

Commit

Permalink
[XLA] Dump and parse the entry_computation_layout in string format; the proto format already serialized this information.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 450528372
  • Loading branch information
blakehechtman authored and tensorflower-gardener committed May 23, 2022
1 parent 58c6888 commit 21754c7
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 153 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5672,11 +5672,13 @@ cc_library(
srcs = ["hlo_parser.cc"],
hdrs = ["hlo_parser.h"],
deps = [
":computation_layout",
":hlo",
":hlo_lexer",
":shape_inference",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/service/computation_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ std::string ComputationLayout::ToString() const {
for (auto& param_layout : parameter_layouts_) {
params.push_back(param_layout.ToString());
}
return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ",
return absl::StrCat("(", absl::StrJoin(params, ","), ")->",
result_layout_.ToString());
}

Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/service/hlo_instruction_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1634,7 +1634,7 @@ TEST_F(HloInstructionTest, StringifyAsyncOps) {
module->AddEmbeddedComputation(std::move(async_computation));

const std::string expected_with_syntax_sugar =
R"(HloModule StringifyAsyncOps
R"(HloModule StringifyAsyncOps, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
ENTRY %Entry (p0: f32[10]) -> f32[20] {
%p0 = f32[10]{0} parameter(0)
Expand All @@ -1646,7 +1646,7 @@ ENTRY %Entry (p0: f32[10]) -> f32[20] {
)";
EXPECT_EQ(module->ToString(), expected_with_syntax_sugar);
const std::string expected_without_syntax_sugar =
R"(HloModule StringifyAsyncOps
R"(HloModule StringifyAsyncOps, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
%AsyncOp (p0.1: f32[10]) -> f32[20] {
%p0.1 = f32[10]{0} parameter(0)
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/compiler/xla/service/hlo_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ absl::Cord HloModule::ToCord(const HloPrintOptions& options) const {
if (config_.alias_passthrough_params()) {
result.Append(", alias_passthrough_params=true");
}
if (config_.has_entry_computation_layout()) {
LOG(ERROR) << "HAS CONFIG " << this->name();
result.Append(", entry_computation_layout={");
result.Append(entry_computation_layout().ToString());
result.Append("}");
}
if (config_.allow_spmd_sharding_propagation_to_output()) {
result.Append(", allow_spmd_sharding_propagation_to_output=true");
}
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/compiler/xla/service/hlo_module_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ class HloModuleConfig {
return &(*entry_computation_layout_);
}

// Resets the entry computation layout to the unset state, as if it had
// never been assigned.
void clear_entry_computation_layout() { entry_computation_layout_.reset(); }

// Returns whether to enable HLO-level profiling.
bool hlo_profiling_enabled() const {
return debug_options_.xla_hlo_profile();
Expand Down
12 changes: 7 additions & 5 deletions tensorflow/compiler/xla/service/hlo_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,16 @@ TEST_F(HloModuleTest, LargeConstantToString) {
module->AddEntryComputation(builder.Build());

EXPECT_EQ(
"HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n "
"ROOT %constant = f32[16]{0} constant({...})\n}\n\n",
"HloModule LargeConstantToString, "
"entry_computation_layout={()->f32[16]{0}}\n\nENTRY %Constant () -> "
"f32[16] {\n ROOT %constant = f32[16]{0} constant({...})\n}\n\n",
module->ToString(HloPrintOptions().set_print_large_constants(false)));

EXPECT_EQ(
"HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n "
"ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, "
"42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n",
"HloModule LargeConstantToString, "
"entry_computation_layout={()->f32[16]{0}}\n\nENTRY %Constant () -> "
"f32[16] {\n ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, "
"42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n",
module->ToString(HloPrintOptions().set_print_large_constants(true)));
}

Expand Down
70 changes: 69 additions & 1 deletion tensorflow/compiler/xla/service/hlo_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
Expand Down Expand Up @@ -266,6 +267,7 @@ class HloParserImpl : public HloParser {
kEnum,
kRandomAlgorithm,
kAliasing,
kComputationLayout,
kInstructionAliasing,
kCustomCallSchedule,
kCustomCallApiVersion,
Expand Down Expand Up @@ -509,6 +511,9 @@ class HloParserImpl : public HloParser {
// fails.
bool ParseAliasing(AliasingData* data);

// Parses the entry computation layout.
bool ParseComputationLayout(ComputationLayout* computation_layout);

// Parses the per-instruction aliasing information from string `s`, returns
// `false` if it fails.
bool ParseInstructionOutputOperandAliasing(
Expand Down Expand Up @@ -772,6 +777,48 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
return true;
}

// Parses an entry computation layout of the form
//   {(param_shape_0, param_shape_1, ...) -> result_shape}
// appending one ShapeLayout per parameter shape to `computation_layout` and
// setting its result layout. Returns false (with a token error already
// reported via ParseToken/ParseShape) on malformed input.
bool HloParserImpl::ParseComputationLayout(
    ComputationLayout* computation_layout) {
  if (!ParseToken(TokKind::kLbrace,
                  // Fixed copy-paste from ParseAliasing: this diagnostic
                  // previously said "aliasing description".
                  "Expects '{' at the start of computation layout")) {
    return false;
  }
  if (!ParseToken(TokKind::kLparen, "Expects ( before parameter shape list")) {
    return false;
  }
  // Parameter list may be empty: the loop is skipped entirely when the next
  // token is already ')'.
  while (lexer_.GetKind() != TokKind::kRparen) {
    Shape param;
    if (!ParseShape(&param)) {
      return false;
    }
    computation_layout->add_parameter_layout(ShapeLayout(param));
    if (lexer_.GetKind() == TokKind::kRparen) {
      // Trailing shape with no comma: fall through to consume ')'.
      break;
    }
    if (!ParseToken(TokKind::kComma, "Expects , between parameter shapes")) {
      return false;
    }
  }

  if (!ParseToken(TokKind::kRparen,
                  "Expects ) at end of parameter shape list")) {
    return false;
  }
  if (!ParseToken(TokKind::kArrow, "Expects -> before result shape")) {
    return false;
  }
  Shape result;
  if (!ParseShape(&result)) {
    return false;
  }
  *computation_layout->mutable_result_layout() = ShapeLayout(result);
  if (!ParseToken(TokKind::kRbrace,
                  "Expects '}' at the end of computation layouts")) {
    return false;
  }
  return true;
}

bool HloParserImpl::ParseInstructionOutputOperandAliasing(
std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>*
aliasing_output_operand_pairs) {
Expand Down Expand Up @@ -876,12 +923,16 @@ bool HloParserImpl::ParseHloModule(HloModule* module) {
absl::optional<AliasingData> aliasing_data;
absl::optional<bool> alias_passthrough_params;
absl::flat_hash_map<std::string, AttrConfig> attrs;
absl::optional<ComputationLayout> entry_computation_layout;

attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
attrs["input_output_alias"] = {/*required=*/false, AttrTy::kAliasing,
&aliasing_data};
attrs["alias_passthrough_params"] = {/*required=*/false, AttrTy::kBool,
&alias_passthrough_params};
attrs["entry_computation_layout"] = {/*required=*/false,
AttrTy::kComputationLayout,
&entry_computation_layout};
if (!ParseAttributes(attrs)) {
return false;
}
Expand All @@ -893,9 +944,17 @@ bool HloParserImpl::ParseHloModule(HloModule* module) {
if (is_scheduled.has_value() && *is_scheduled) {
TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
HloModuleConfig config = module->config();
bool default_config = true;
if (alias_passthrough_params.has_value() && *alias_passthrough_params) {
HloModuleConfig config = module->config();
config.set_alias_passthrough_params(true);
default_config = false;
}
if (entry_computation_layout.has_value()) {
*config.mutable_entry_computation_layout() = *entry_computation_layout;
default_config = false;
}
if (!default_config) {
module->set_config(config);
}
if (aliasing_data) {
Expand Down Expand Up @@ -4278,6 +4337,15 @@ bool HloParserImpl::ParseAttributeHelper(
->emplace(aliasing_data);
return true;
}
case AttrTy::kComputationLayout: {
ComputationLayout computation_layout(ShapeLayout(Shape{}));
if (!ParseComputationLayout(&computation_layout)) {
return false;
}
static_cast<optional<ComputationLayout>*>(attr_out_ptr)
->emplace(computation_layout);
return true;
}
case AttrTy::kInstructionAliasing: {
std::vector<std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
aliasing_output_operand_pairs;
Expand Down
Loading

0 comments on commit 21754c7

Please sign in to comment.