Skip to content

Commit

Permalink
[XLA][MLIR] Remove data race in SelectAndScatterOp lowered to paralle…
Browse files Browse the repository at this point in the history
…l loops.

PiperOrigin-RevId: 307774393
Change-Id: I1a6e1892821533dbe3be3c31c6b4ea0455167f78
  • Loading branch information
pifon2a authored and tensorflower-gardener committed Apr 22, 2020
1 parent f7094d2 commit 5e2e46b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,15 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: }

// Use selected ivs to load element from the SRC buffer.
// CHECK: [[CUR_RES:%.*]] = load [[RESULT_BUF]]{{\[}}[[SEL_RES_I:%.*]]#0,
// CHECK-SAME: [[SEL_RES_I]]#1] : memref<112x112xf32>
// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]

// Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because
// it may happen that several other threads select the same IVs if the windows
// overlap.
// CHECK: generic_atomic_rmw [[RESULT_BUF]]{{\[}}[[SEL_RES_I]]#0,
// CHECK-SAME: [[SEL_RES_I]]#1] : memref<112x112xf32>
// CHECK: ^bb0([[CUR_RES:%.*]]: f32):

// Allocate buffers for the SRC element and the current result value, as needed to adapt the LHLO code.
// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref<f32>
Expand All @@ -387,8 +392,8 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>

// Update RESULT[SELECTED_I, SELECTED_J] with RES.
// CHECK: store [[RES]], [[RESULT_BUF]]{{\[}}[[SEL_RES_I]]#0, [[SEL_RES_I]]#1]
// Atomic RMW terminator that returns updated value.
// CHECK: atomic_yield [[RES]] : f32

// Yield of the parallel loop over the source buffer.
// CHECK: loop.yield
Original file line number Diff line number Diff line change
Expand Up @@ -488,23 +488,27 @@ class SelectAndScatterOpConverter
LogicalResult matchAndRewrite(
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
auto loc = s_and_s_op.getLoc();
InitializeOutput(s_and_s_op, &rewriter);
loop::ParallelOp loop_over_src =
MakeLoopOverShape(s_and_s_op.getLoc(), s_and_s_op.source(), &rewriter);
MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
rewriter.setInsertionPointToStart(loop_over_src.getBody());

// Compute indices of the selected element in the window.
auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
// Compute `acc_result = scatter(out[selected_ivs], src_element)`.
Value acc_result =
Scatter(s_and_s_op, loop_over_src, selected_ivs, &rewriter);

// Updates `out[selected_ivs]`.
//
// TODO(pifon): This has to become AtomicRMWOp that updates an element of
// s_and_s_op.out().
rewriter.create<StoreOp>(s_and_s_op.getLoc(), acc_result, s_and_s_op.out(),
selected_ivs);
// Load the `source` element at the current induction variables of `loop_over_src`.
auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
loop_over_src.getInductionVars());

// Compute `out[selected_ivs] = scatter(out[selected_ivs], src_element)`.
auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
selected_ivs);
OpBuilder rmw_builder = rmw.getBodyBuilder();
auto acc_result =
ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()},
&s_and_s_op.scatter().front(), &rmw_builder);
rmw_builder.create<AtomicYieldOp>(loc, acc_result);

rewriter.replaceOp(s_and_s_op, llvm::None);
return success();
Expand Down Expand Up @@ -685,19 +689,6 @@ class SelectAndScatterOpConverter
}
return if_init.getResults();
}

// Emits the ops that compute the scattered value for the current source
// element and returns the value produced by ApplySingleResultLhloCode.
//
// Loads `out[selected_ivs]` and the `source` element at the induction
// variables of `loop_over_src`, then hands both values, the first block of
// the op's `scatter` region, and the builder `b` to ApplySingleResultLhloCode.
//
// NOTE(review): the load of `out[selected_ivs]` below is a plain LoadOp, so
// the read-modify-write of `out` is not a single atomic update — presumably
// racy when several parallel iterations select the same indices; confirm
// against the caller that stores the returned value.
Value Scatter(xla_lhlo::SelectAndScatterOp s_and_s_op,
loop::ParallelOp loop_over_src, ValueRange selected_ivs,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();

// Current value of the output buffer at the selected indices.
auto acc_current = b->create<LoadOp>(loc, s_and_s_op.out(), selected_ivs);
// Source element at the current position of the loop over `source`.
auto src_elem = b->create<LoadOp>(loc, s_and_s_op.source(),
loop_over_src.getInductionVars());

// Operand order is (source element, current output value).
return ApplySingleResultLhloCode(loc, {src_elem, acc_current},
&s_and_s_op.scatter().front(), b);
}
};

struct LhloLegalizeToParallelLoops
Expand Down

0 comments on commit 5e2e46b

Please sign in to comment.