Skip to content

Commit

Permalink
Reverts c04aec9
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698654038
  • Loading branch information
naummo authored and Google-ML-Automation committed Nov 21, 2024
1 parent 6568713 commit e72b449
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 61 deletions.
5 changes: 2 additions & 3 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -654,15 +654,14 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> {
I32:$amount,
Optional<I32>:$device_id, // For remote DMAs
Optional<I32>:$core_id, // For megacore
Optional<I32>:$subcore_id, // For the SC vector subcore
OptionalAttr<TPU_CoreTypeEnum>:$core_type
);
let assemblyFormat = [{
$semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`subcore_id` $subcore_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore)
$semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore)
}];
let hasVerifier = 1;
let builders = [
// A backward-compatible builder that sets `subcore_id` and `core_type` to nullptr.
// A backward-compatible builder that sets `core_type` to nullptr.
OpBuilder<(ins "Value":$semaphore, "Value":$amount,
"Value":$device_id, "Value":$core_id)>,
];
Expand Down
36 changes: 9 additions & 27 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state,
Value semaphore, Value amount, Value device_id,
Value core_id) {
build(builder, state, semaphore, amount, device_id, core_id,
/*subcore_id=*/nullptr, /*core_type=*/nullptr);
/*core_type=*/nullptr);
}

LogicalResult SemaphoreSignalOp::verify() {
Expand All @@ -861,39 +861,21 @@ LogicalResult SemaphoreSignalOp::verify() {
CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc);
CoreType target_core_type = getCoreType().value_or(issuing_core_type);

if (getCoreId() == nullptr && getDeviceId() == nullptr &&
getSubcoreId() == nullptr) {
if (getCoreId() == nullptr && getDeviceId() == nullptr) {
if (target_core_type != issuing_core_type) {
return emitOpError(absl::StrFormat(
"Target core type (%s) must match source core type "
"(%s) when device_id, core_id and subcore_id are not specified",
stringifyCoreType(target_core_type),
stringifyCoreType(issuing_core_type)));
return emitOpError(
absl::StrFormat("Target core type (%s) must match source core type "
"(%s) when device_id and core_id are not specified",
stringifyCoreType(target_core_type),
stringifyCoreType(issuing_core_type)));
}
}
if (target_core_type == CoreType::kScVectorSubcore &&
issuing_core_type != CoreType::kScVectorSubcore &&
getSubcoreId() == nullptr) {
return emitOpError(
"Subcore ID must be specified for the SC vector subcore");
}
if (target_core_type != CoreType::kScVectorSubcore &&
getSubcoreId() != nullptr) {
return emitOpError(
"Subcore ID must be specified only for the SC vector subcore");
}
if ((issuing_core_type == CoreType::kTc &&
(target_core_type == CoreType::kScScalarSubcore ||
target_core_type == CoreType::kScVectorSubcore)) ||
((issuing_core_type == CoreType::kScScalarSubcore ||
issuing_core_type == CoreType::kScVectorSubcore) &&
target_core_type == CoreType::kScScalarSubcore) ||
(issuing_core_type == CoreType::kScScalarSubcore &&
target_core_type == CoreType::kTc)) {
return emitOpError("Signalling between TC and SC is not implemented");
}
if (target_core_type == CoreType::kScVectorSubcore &&
(getCoreId() != nullptr || getDeviceId() != nullptr)) {
return emitOpError("Signalling remote SC vector subcores is not supported");
}
return success();
}

Expand Down
44 changes: 13 additions & 31 deletions jaxlib/mosaic/dialect/tpu/transforms/serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,19 @@ limitations under the License.

// We need to keep some extra headers for the code in tpu_passes.h.inc.

#include <cstdint>
#include <memory> // IWYU pragma: keep
#include <optional>
#include <string>
#include <string_view>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "absl/strings/str_format.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/OpDefinition.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
Expand All @@ -45,7 +43,7 @@ namespace {

constexpr std::string_view kMangledDialect = "stable_mosaic.";
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
constexpr int kVersion = 4;
constexpr int kVersion = 3;

StringRef mangle(StringRef name, std::string* storage) {
storage->clear();
Expand Down Expand Up @@ -88,37 +86,21 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) {

LogicalResult semaphore_signal_rule(Operation* op, int version) {
// Added AttrSizedOperandSegments and core_id in version 2.
// Added subcore_id in version 4.
if (version < 2) {
if (op->getNumOperands() == 2) { // Local signal.
op->setAttr(
OpTrait::AttrSizedOperandSegments<
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0, 0}));
op->setAttr(OpTrait::AttrSizedOperandSegments<
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0}));
} else if (op->getNumOperands() == 3) { // Remote signal.
op->setAttr(
OpTrait::AttrSizedOperandSegments<
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0, 0}));
}
return op->emitError("Unexpected operand count in tpu.semaphore_signal");
} else if (version < 4) {
ArrayRef<int32_t> operand_segment_sizes =
op->getAttrOfType<DenseI32ArrayAttr>(
OpTrait::AttrSizedOperandSegments<
SemaphoreSignalOp>::getOperandSegmentSizeAttr());
if (operand_segment_sizes.size() != 4) {
return op->emitError(absl::StrFormat(
"Expected operand count to be 4 in tpu.semaphore_signal. Got %d",
operand_segment_sizes.size()));
// Hardcoding that one optional value is device_id, not core_id. This
// could misinterpret sem_signals where core_id is specified, but
// device_id isn't.
op->setAttr(OpTrait::AttrSizedOperandSegments<
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0}));
} else {
return op->emitError("Unexpected operand count in tpu.semaphore_signal");
}
SmallVector<int32_t, 5> new_operand_segment_sizes(
operand_segment_sizes.begin(), operand_segment_sizes.end());
new_operand_segment_sizes.push_back(0);
op->setAttr(OpTrait::AttrSizedOperandSegments<
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
mlir::DenseI32ArrayAttr::get(op->getContext(),
new_operand_segment_sizes));
}
return success();
}
Expand Down

0 comments on commit e72b449

Please sign in to comment.