[XLA:GPU][NFC] Fix style in in_place_dynamic_update_slice_mlir.cc for consistency

PiperOrigin-RevId: 618199494
tyb0807 authored and tensorflower-gardener committed Mar 22, 2024
1 parent f0ae658 commit 5abfc19
Showing 1 changed file with 42 additions and 23 deletions.
@@ -31,6 +31,7 @@ limitations under the License.
 #include "mlir/IR/MLIRContext.h" // from @llvm-project
 #include "mlir/IR/Value.h" // from @llvm-project
 #include "mlir/IR/ValueRange.h" // from @llvm-project
+#include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
@@ -46,6 +47,20 @@ namespace xla {
 namespace gpu {
 namespace {

+using llvm::SmallVector;
+using mlir::ImplicitLocOpBuilder;
+using mlir::MLIRContext;
+using mlir::Value;
+using mlir::ValueRange;
+using mlir::arith::AddIOp;
+using mlir::func::ReturnOp;
+using mlir::tensor::InsertOp;
+using mlir_converter::ApplyAffineMap;
+using mlir_converter::CallTargetProvider;
+using mlir_converter::ClampIndex;
+using mlir_converter::PartitionedComputations;
+using mlir_converter::ProvideParameter;
+
 constexpr int kDUSUpdateIndex = 1;

 }  // namespace
@@ -81,14 +96,13 @@ MlirInPlaceDynamicUpdateSliceFusion::GetInstructionsWithCustomCodegen(
 }

 absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction(
-    const mlir_converter::PartitionedComputations& computations,
-    const mlir_converter::CallTargetProvider& call_targets,
-    mlir::func::FuncOp entry_function,
+    const PartitionedComputations& computations,
+    const CallTargetProvider& call_targets, mlir::func::FuncOp entry_function,
     const HloFusionInstruction& fusion) const {
-  mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function);
+  ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function);
   b.setInsertionPointToStart(entry_function.addEntryBlock());

-  mlir::MLIRContext* mlir_context = entry_function.getContext();
+  MLIRContext* mlir_context = entry_function.getContext();
   IndexingContext indexing_context{mlir_context};

   auto indexing = *ComputeThreadIdToInputIndexing(
@@ -105,40 +119,45 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction(
       fusion.fused_instructions_computation());
   const auto& dus_subgraph = root_computation.FindSubgraph(dus_ops_.front());

-  const auto* dus_instr = dus_ops_.front();
+  const auto* dus_instr =
+      Cast<HloDynamicUpdateSliceInstruction>(dus_ops_.front());
   const auto& update_shape = dus_instr->operand(kDUSUpdateIndex)->shape();
   auto result_tensors = EmitThreadLoopNest(
       b, output_tensor_args, indexing,
-      [&](mlir::ValueRange output_tensors, mlir::ValueRange dim_values,
-          mlir::ValueRange symbol_values) -> llvm::SmallVector<mlir::Value> {
-        auto input_indices = mlir_converter::ApplyAffineMap(
-            indexing.GetAffineMap(), dim_values, symbol_values, b);
-        llvm::SmallVector<mlir::Value> update_indices;
+      [&](ValueRange output_tensors, ValueRange dim_values,
+          ValueRange symbol_values) -> llvm::SmallVector<Value> {
+        auto input_indices = ApplyAffineMap(indexing.GetAffineMap(), dim_values,
+                                            symbol_values, b);
+        SmallVector<Value> update_indices;
         for (int i = 0; i < update_shape.rank(); ++i) {
           int64_t update_size = update_shape.dimensions(i);
-          auto start_index = mlir_converter::ProvideParameter(
-              dus_subgraph, dus_instr, i + 2, {}, call_targets, entry_function,
-              b)[0];
-          start_index = mlir_converter::ClampIndex(
+          auto start_index =
+              ProvideParameter(dus_subgraph, dus_instr,
+                               i + dus_instr->first_index_operand_number(), {},
+                               call_targets, entry_function, b)[0];
+          start_index = ClampIndex(
              start_index,
              primitive_util::IsUnsignedIntegralType(
-                 dus_instr->operand(i + 2)->shape().element_type()),
+                 dus_instr
+                     ->operand(i + dus_instr->first_index_operand_number())
+                     ->shape()
+                     .element_type()),
              dus_instr->shape().dimensions(i) - update_size, b);

           update_indices.push_back(
-              b.create<mlir::arith::AddIOp>(input_indices[i], start_index));
+              b.create<AddIOp>(input_indices[i], start_index));
         }

-        auto updated_value = mlir_converter::ProvideParameter(
-            dus_subgraph, dus_instr, kDUSUpdateIndex, input_indices,
-            call_targets, entry_function, b)[0];
-        auto insert = b.create<mlir::tensor::InsertOp>(
-            updated_value, output_tensors[0], update_indices);
+        auto updated_value =
+            ProvideParameter(dus_subgraph, dus_instr, kDUSUpdateIndex,
+                             input_indices, call_targets, entry_function, b)[0];
+        auto insert = b.create<InsertOp>(updated_value, output_tensors[0],
+                                         update_indices);

         return {insert.getResult()};
       });

-  b.create<mlir::func::ReturnOp>(result_tensors);
+  b.create<ReturnOp>(result_tensors);
   return absl::OkStatus();
 }

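For readers skimming the diff: the one non-mechanical edit casts the dynamic-update-slice op to HloDynamicUpdateSliceInstruction so the start-index operand offset comes from first_index_operand_number() instead of the hardcoded i + 2; dynamic-update-slice operands are laid out as (operand, update, start indices...), so the value is still 2 and the change stays NFC. Everything else applies one style pattern: pull frequently used names into the file's anonymous namespace with using-declarations so call sites stay short. Below is a minimal standalone sketch of that pattern, not taken from the commit, with a made-up namespace (somelib::math) standing in for mlir / mlir_converter:

#include <cstdio>

// Stand-in for a deeply nested library namespace such as mlir::arith or
// mlir_converter in the real file.
namespace somelib {
namespace math {
inline int Add(int a, int b) { return a + b; }
}  // namespace math
}  // namespace somelib

namespace {

// One using-declaration near the top of the file...
using somelib::math::Add;

}  // namespace

int main() {
  // ...lets every later call site drop the qualification, the same way the
  // diff turns mlir_converter::ApplyAffineMap(...) into ApplyAffineMap(...).
  std::printf("2 + 3 = %d\n", Add(2, 3));
  return 0;
}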
