Skip to content

Commit

Permalink
Add new cost model to enable looped einsum and choose which collective ops to decompose.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 407597684
Change-Id: I19c076167d5e8215401eeaa81cd17174683bbbc8
  • Loading branch information
tensorflower-gardener committed Nov 4, 2021
1 parent 58c4c67 commit eab602d
Show file tree
Hide file tree
Showing 6 changed files with 397 additions and 53 deletions.
20 changes: 15 additions & 5 deletions tensorflow/compiler/xla/service/hlo_cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
return Status::OK();
}

Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
/* static */
int64_t HloCostAnalysis::GetDotFlops(const HloInstruction* dot) {
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& dot_shape = dot->shape();
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
Expand All @@ -331,8 +332,11 @@ Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
reduction_width *= lhs_shape.dimensions(dim);
}
// Each output element requires reduction_width FMA operations.
current_properties_[kFlopsKey] =
kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width;
return kFmaFlops * ShapeUtil::ElementsIn(dot_shape) * reduction_width;
}

// Visitor handler for dot instructions. Obtains the flop estimate from the
// static helper GetDotFlops() and stores it under kFlopsKey in
// current_properties_. Always returns Status::OK().
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
  const int64_t dot_flops = GetDotFlops(dot);
  current_properties_[kFlopsKey] = dot_flops;
  return Status::OK();
}

Expand Down Expand Up @@ -567,7 +571,9 @@ Status HloCostAnalysis::HandleAddDependency(
return Status::OK();
}

Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
/* static */
int64_t HloCostAnalysis::GetConvolutionFlops(
const HloInstruction* convolution) {
auto lhs = convolution->operand(0);
auto rhs = convolution->operand(1);
Window window = convolution->window();
Expand Down Expand Up @@ -688,7 +694,11 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
(input_feature / convolution->feature_group_count()) * output_feature *
(batch / convolution->batch_group_count()) *
Product(valid_position_counts);
current_properties_[kFlopsKey] = fma_count * kFmaFlops;
return fma_count * kFmaFlops;
}

// Visitor handler for convolution instructions. Obtains the flop estimate
// from the static helper GetConvolutionFlops() and stores it under kFlopsKey
// in current_properties_. Always returns Status::OK().
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
  const int64_t convolution_flops = GetConvolutionFlops(convolution);
  current_properties_[kFlopsKey] = convolution_flops;
  return Status::OK();
}

Expand Down
6 changes: 6 additions & 0 deletions tensorflow/compiler/xla/service/hlo_cost_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
ShapeIndex index = {});
static std::string GetOutputBytesAccessedKey(ShapeIndex index = {});

// Returns the estimated number of flops for `convolution`.
//
// The estimate is the FMA count scaled by kFmaFlops, where the FMA count is
//   (input_feature / feature_group_count) * output_feature *
//   (batch / batch_group_count) * Product(valid_position_counts).
// Static so it can be called without a HloCostAnalysis instance;
// HandleConvolution() stores this value under kFlopsKey.
static int64_t GetConvolutionFlops(const HloInstruction* convolution);

// Returns the estimated number of flops for `dot`.
//
// Each output element is charged `reduction_width` FMA operations, so the
// result is kFmaFlops * ElementsIn(dot->shape()) * reduction_width, where
// reduction_width is a product of LHS operand dimension sizes (presumably the
// contracting dimensions -- confirm against the implementation).
// Static so it can be called without a HloCostAnalysis instance;
// HandleDot() stores this value under kFlopsKey.
static int64_t GetDotFlops(const HloInstruction* dot);

protected:
typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties;

Expand Down
4 changes: 3 additions & 1 deletion tensorflow/compiler/xla/service/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3265,7 +3265,9 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
trace_instruction_ = trace_instruction;
}

bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
// Reports whether this instruction's parent computation is a fusion
// computation. An instruction with no parent (parent_ == nullptr) is
// reported as not fused rather than dereferencing a null pointer.
bool HloInstruction::IsFused() const {
  if (parent_ == nullptr) {
    return false;
  }
  return parent_->IsFusionComputation();
}

bool HloInstruction::IsCustomCall(absl::string_view target) const {
return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
Expand Down
Loading

0 comments on commit eab602d

Please sign in to comment.