Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically generate an overload w/o QueueId #17640

Merged
merged 20 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 54 additions & 31 deletions ttnn/cpp/ttnn/decorators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,15 @@ concept PrimitiveOperationConcept = device_operation::DeviceOperationConcept<ope
template <typename operation_t>
concept CompositeOperationConcept = !PrimitiveOperationConcept<operation_t>;

// Concept: satisfied when Op exposes a static invoke() callable with Args.
// Used by registered_operation_t::operator() below to decide whether a
// DefaultQueueId must be injected before dispatching to the operation.
template <typename Op, typename... Args>
concept HasInvoke = requires {
{ Op::invoke(std::declval<Args>()...) };
};

// Concept: satisfied when the pack is non-empty and its first argument,
// after std::decay, is exactly T. Note that for an empty pack the
// short-circuiting `sizeof...(Args) > 0` makes the concept false.
template <typename T, typename... Args>
concept FirstArgIs =
sizeof...(Args) > 0 && std::same_as<std::decay_t<std::tuple_element_t<0, std::tuple<Args&&...>>>, T>;

template <reflect::fixed_string cpp_fully_qualified_name, typename operation_t, bool auto_launch_op>
struct registered_operation_t {
static constexpr auto is_primitive = PrimitiveOperationConcept<operation_t>;
Expand All @@ -216,6 +225,45 @@ struct registered_operation_t {
return detail::python_fully_qualified_name(std::string{cpp_fully_qualified_name});
}

// --- operator() Overloads ---

// (1) Overload when the first argument is a QueueId.
// Caller supplied an explicit QueueId as the first argument: forward
// everything straight through to the traced invocation path, no injection.
template <typename First, typename... Rest>
requires std::same_as<std::decay_t<First>, QueueId>
auto operator()(First&& first, Rest&&... rest) const {
return traced_invoke(std::forward<First>(first), std::forward<Rest>(rest)...);
}

// (2a) Overload when no QueueId is provided AND the operation is invocable without a QueueId.
//
// NOTE: FirstArgIs<QueueId, Args...> is false for an empty pack, so this
// constraint also covers zero-argument calls whenever operation_t::invoke()
// exists. The previous form prefixed both (2a) and (2b) with
// `sizeof...(Args) == 0 ||`; for a zero-argument call that made BOTH
// overloads viable, and since the two disjuncts are textually distinct
// atomic constraints neither subsumed the other — an ambiguous call.
// Dropping the disjunct keeps coverage exhaustive and makes the two
// overloads mutually exclusive (HasInvoke vs !HasInvoke) for every arity.
template <typename... Args>
    requires(!FirstArgIs<QueueId, Args...> && HasInvoke<operation_t, Args && ...>)
auto operator()(Args&&... args) const {
    return traced_invoke(std::forward<Args>(args)...);
}

// (2b) Overload when no QueueId is provided but the operation is NOT invocable without a QueueId,
// so we inject DefaultQueueId.
template <typename... Args>
    requires(
        !FirstArgIs<QueueId, Args...> && !HasInvoke<operation_t, Args && ...> &&
        HasInvoke<operation_t, QueueId, Args && ...>)
auto operator()(Args&&... args) const {
    return traced_invoke(DefaultQueueId, std::forward<Args>(args)...);
}

private:
// Wraps every operation invocation with debug logging and GraphTracker
// start/end notifications, then forwards to the appropriate invoke().
// The tracker deliberately receives `args...` as lvalues BEFORE the
// forwarding call, so it observes the arguments prior to any move.
template <typename... args_t>
auto traced_invoke(args_t&&... args) const {
tt::log_debug(tt::LogOp, "Started C++ ttnn operation: {}", std::string_view{cpp_fully_qualified_name});
tt::tt_metal::GraphTracker::instance().track_function_start(cpp_fully_qualified_name, args...);

auto output = invoke(std::forward<args_t>(args)...);

tt::tt_metal::GraphTracker::instance().track_function_end(output);
tt::log_debug(tt::LogOp, "Finished C++ ttnn operation: {}", std::string_view{cpp_fully_qualified_name});
return output;
}

template <typename... args_t>
requires PrimitiveOperationConcept<operation_t>
auto invoke(QueueId queue_id, args_t&&... args) const {
Expand All @@ -234,6 +282,12 @@ struct registered_operation_t {
return invoke(DefaultQueueId, std::forward<args_t>(args)...);
}

// Composite operations (anything that is not a PrimitiveOperationConcept)
// are dispatched to invoke_composite(), which handles auto-launch behavior.
template <typename... args_t>
requires(CompositeOperationConcept<operation_t>)
auto invoke(args_t&&... args) const {
return invoke_composite(std::forward<args_t>(args)...);
}

template <typename... args_t>
requires(not auto_launch_op)
auto invoke_composite(args_t&&... args) const {
Expand Down Expand Up @@ -300,30 +354,6 @@ struct registered_operation_t {
"Tensor(s).");
}
}

// Composite operations (anything that is not a PrimitiveOperationConcept)
// are dispatched to invoke_composite(), which handles auto-launch behavior.
template <typename... args_t>
requires(CompositeOperationConcept<operation_t>)
auto invoke(args_t&&... args) const {
return invoke_composite(std::forward<args_t>(args)...);
}

template <typename... args_t>
auto operator()(args_t&&... args) const {
tt::log_debug(tt::LogOp, "Started C++ ttnn operation: {}", std::string_view{cpp_fully_qualified_name});
tt::tt_metal::GraphTracker::instance().track_function_start(cpp_fully_qualified_name, args...);
ayerofieiev-tt marked this conversation as resolved.
Show resolved Hide resolved
auto output = invoke(std::forward<args_t>(args)...);

// Should every output tensor be tracked?
/*
if (GraphTracker::instance().is_enabled()) {
output = tt::stl::reflection::transform_object_of_type<Tensor>(tt::tt_metal::set_tensor_id, output);
}
*/

tt::tt_metal::GraphTracker::instance().track_function_end(output);
tt::log_debug(tt::LogOp, "Finished C++ ttnn operation: {}", std::string_view{cpp_fully_qualified_name});
return output;
}
};

template <reflect::fixed_string cpp_fully_qualified_name>
Expand Down Expand Up @@ -393,13 +423,6 @@ constexpr auto register_operation_with_auto_launch_op() {
return register_operation_impl<cpp_fully_qualified_name, operation_t, true>();
}

namespace detail {
// Adapter that gives a lambda (passed as a non-type template parameter)
// the static-invoke shape expected by registered_operation_t: invoke()
// simply perfect-forwards its arguments to lambda_t.
template <auto lambda_t>
struct lambda_operation_t {
static auto invoke(auto&&... args) { return lambda_t(std::forward<decltype(args)>(args)...); }
};
} // namespace detail

} // namespace decorators

using ttnn::decorators::register_operation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ Tensor all_gather(
if (input_tensor.get_dtype() != DataType::BFLOAT16 && input_tensor.get_dtype() != DataType::FLOAT32) {
input_tensor = ttnn::typecast(input_tensor, DataType::BFLOAT16);
}
input_tensor = ttnn::pad(ttnn::DefaultQueueId, input_tensor, padding, 0, false, std::nullopt);
input_tensor = ttnn::pad(input_tensor, padding, 0, false, std::nullopt);
if (original_dtype != input_tensor.get_dtype()) {
input_tensor = ttnn::typecast(input_tensor, original_dtype);
}
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Tensor to_layout_impl(
{0, 0},
{0, padded_output_shape[2] - output_shape[2]},
{0, padded_output_shape[3] - output_shape[3]}};
tensor = ttnn::pad(ttnn::DefaultQueueId, tensor, padding, 0, true, std::nullopt);
tensor = ttnn::pad(tensor, padding, 0, true, std::nullopt);
return ttnn::tilize(tensor, output_memory_config, dtype, use_multicore_tilize);
} else {
PadValue pad_value_variant;
Expand Down
9 changes: 0 additions & 9 deletions ttnn/cpp/ttnn/operations/data_movement/concat/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,15 +329,6 @@ ttnn::Tensor ConcatOperation::invoke(
return res;
}

// Convenience overload without an explicit queue id: forwards to the
// queue-aware overload on DefaultQueueId.
// NOTE(review): the previous body wrapped optional_output_tensor in
// std::move(), but moving from a const lvalue reference copies anyway
// (clang-tidy performance-move-const-arg); the plain pass-through below
// is identical in behavior and no longer implies a transfer.
ttnn::Tensor ConcatOperation::invoke(
    const std::vector<ttnn::Tensor>& input_tensors,
    int dim,
    const std::optional<MemoryConfig>& memory_config,
    const std::optional<ttnn::Tensor>& optional_output_tensor,
    unsigned int groups) {
    return invoke(DefaultQueueId, input_tensors, dim, memory_config, optional_output_tensor, groups);
}

} // namespace data_movement
} // namespace operations
} // namespace ttnn
7 changes: 0 additions & 7 deletions ttnn/cpp/ttnn/operations/data_movement/concat/concat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@ struct ConcatOperation {
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<ttnn::Tensor>& optional_output_tensor = std::nullopt,
unsigned int groups = 1);

static ttnn::Tensor invoke(
const std::vector<ttnn::Tensor>& input_tensors,
int dim,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<ttnn::Tensor>& optional_output_tensor = std::nullopt,
unsigned int groups = 1);
};

} // namespace data_movement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,4 @@ ttnn::Tensor FillPadOperation::invoke(
.at(0);
}

// Convenience overload without an explicit queue id: delegates to the
// queue-aware overload using DefaultQueueId.
ttnn::Tensor FillPadOperation::invoke(
const ttnn::Tensor& input_tensor, float fill_value, const std::optional<ttnn::MemoryConfig>& memory_config_arg) {
return invoke(DefaultQueueId, input_tensor, fill_value, memory_config_arg);
}

} // namespace ttnn::operations::data_movement
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ struct FillPadOperation {
const ttnn::Tensor& input_tensor,
float fill_value,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt);

static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
float fill_value,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt);
};

} // namespace data_movement
Expand Down
26 changes: 0 additions & 26 deletions ttnn/cpp/ttnn/operations/data_movement/fill_rm/fill_rm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,6 @@ ttnn::Tensor FillRMOperation::invoke(
.at(0);
}

// Convenience overload without an explicit queue id: delegates to the
// queue-aware overload using DefaultQueueId, preserving argument order.
ttnn::Tensor FillRMOperation::invoke(
uint32_t N,
uint32_t C,
uint32_t H,
uint32_t W,
uint32_t hFill,
uint32_t wFill,
const ttnn::Tensor& any,
float val_hi,
float val_lo,
const std::optional<ttnn::MemoryConfig>& memory_config_arg) {
return invoke(DefaultQueueId, N, C, H, W, hFill, wFill, any, val_hi, val_lo, memory_config_arg);
}

ttnn::Tensor FillOnesRMOperation::invoke(
QueueId queue_id,
uint32_t N,
Expand All @@ -60,16 +46,4 @@ ttnn::Tensor FillOnesRMOperation::invoke(
.at(0);
}

// Convenience overload without an explicit queue id: delegates to the
// queue-aware overload using DefaultQueueId, preserving argument order.
ttnn::Tensor FillOnesRMOperation::invoke(
uint32_t N,
uint32_t C,
uint32_t H,
uint32_t W,
uint32_t hFill,
uint32_t wFill,
const ttnn::Tensor& any,
const std::optional<ttnn::MemoryConfig>& memory_config_arg) {
return invoke(DefaultQueueId, N, C, H, W, hFill, wFill, any, memory_config_arg);
}

} // namespace ttnn::operations::data_movement
22 changes: 0 additions & 22 deletions ttnn/cpp/ttnn/operations/data_movement/fill_rm/fill_rm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,6 @@ struct FillRMOperation {
float val_hi,
float val_lo,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt);

static ttnn::Tensor invoke(
uint32_t N,
uint32_t C,
uint32_t H,
uint32_t W,
uint32_t hFill,
uint32_t wFill,
const ttnn::Tensor& any,
float val_hi,
float val_lo,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt);
};

struct FillOnesRMOperation {
Expand All @@ -48,16 +36,6 @@ struct FillOnesRMOperation {
uint32_t wFill,
const ttnn::Tensor& any,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt);

static ttnn::Tensor invoke(
uint32_t N,
uint32_t C,
uint32_t H,
uint32_t W,
uint32_t hFill,
uint32_t wFill,
const ttnn::Tensor& any,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt);
};

} // namespace data_movement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,4 @@ ttnn::Tensor IndexedFillOperation::invoke(
.at(0);
}

// Convenience overload without an explicit queue id: delegates to the
// queue-aware overload using DefaultQueueId.
ttnn::Tensor IndexedFillOperation::invoke(
const ttnn::Tensor& batch_id,
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<ttnn::MemoryConfig>& memory_config,
int64_t dim) {
return invoke(DefaultQueueId, batch_id, input_tensor_a, input_tensor_b, memory_config, dim);
}

} // namespace ttnn::operations::data_movement
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,6 @@ struct IndexedFillOperation {
const ttnn::Tensor& input_tensor_b,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
int64_t dim = 0);

static ttnn::Tensor invoke(
const ttnn::Tensor& batch_id,
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
int64_t dim = 0);
};

} // namespace data_movement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,4 @@ std::vector<ttnn::Tensor> NonZeroIndicesOperation::invoke(
return operation::run_without_autoformat(NonZeroIndices{memory_config}, {input_tensor}, {}, {}, queue_id);
}

// Convenience overload without an explicit queue id: delegates to the
// queue-aware overload using DefaultQueueId.
std::vector<ttnn::Tensor> NonZeroIndicesOperation::invoke(
const ttnn::Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config_arg) {
return invoke(DefaultQueueId, input_tensor, memory_config_arg);
}

} // namespace ttnn::operations::data_movement
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ namespace operations::data_movement {
struct NonZeroIndicesOperation {
static std::vector<ttnn::Tensor> invoke(
QueueId queue_id, const ttnn::Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config);

static std::vector<ttnn::Tensor> invoke(
const ttnn::Tensor& input_tensor, const std::optional<MemoryConfig>& memory_config);
};

} // namespace operations::data_movement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/common/queue_id.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/run_operation.hpp"
#include "reshape.hpp"
#include <tt-metalium/constants.hpp>
Expand Down Expand Up @@ -104,30 +104,6 @@ ttnn::Tensor ReshapeOperation::invoke(
return invoke(queue_id, input_tensor, logical_output_shape, logical_output_shape, memory_config_arg);
}

// Queue-less overload (logical + padded shape): delegates to the
// queue-aware overload using DefaultQueueId.
ttnn::Tensor ReshapeOperation::invoke(
const ttnn::Tensor& input_tensor,
const ttnn::Shape& logical_shape,
const ttnn::Shape& padded_shape,
const std::optional<MemoryConfig>& memory_config) {
return invoke(DefaultQueueId, input_tensor, logical_shape, padded_shape, memory_config);
}

// Queue-less overload (logical shape only): reuses the logical shape as
// the padded shape and forwards to the shape-pair overload above.
ttnn::Tensor ReshapeOperation::invoke(
const ttnn::Tensor& input_tensor,
const ttnn::Shape& logical_shape,
const std::optional<MemoryConfig>& memory_config) {
return invoke(input_tensor, logical_shape, logical_shape, memory_config);
}

// Queue-less overload without a memory config: injects DefaultQueueId and
// std::nullopt for the memory config.
ttnn::Tensor ReshapeOperation::invoke(
const ttnn::Tensor& input_tensor, const ttnn::Shape& logical_shape, const ttnn::Shape& padded_shape) {
return invoke(DefaultQueueId, input_tensor, logical_shape, padded_shape, std::nullopt);
}

// Minimal overload: logical shape doubles as the padded shape; forwards to
// the (tensor, logical, padded) overload.
ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& logical_shape) {
return invoke(input_tensor, logical_shape, logical_shape);
}

ttnn::Tensor ReshapeOperation::invoke(
QueueId queue_id,
const ttnn::Tensor& input_tensor,
Expand All @@ -136,15 +112,4 @@ ttnn::Tensor ReshapeOperation::invoke(
return invoke(queue_id, input_tensor, infer_dims_for_reshape(input_tensor, shape_vector), memory_config_arg);
}

// Queue-less span overload: delegates to the queue-aware span overload
// using DefaultQueueId.
ttnn::Tensor ReshapeOperation::invoke(
const ttnn::Tensor& input_tensor,
tt::stl::Span<const int32_t> shape_vector,
const std::optional<MemoryConfig>& memory_config_arg) {
return invoke(DefaultQueueId, input_tensor, shape_vector, memory_config_arg);
}

// Minimal span overload: injects std::nullopt for the memory config and
// forwards to the span overload above.
ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& input_tensor, tt::stl::Span<const int32_t> shape_vector) {
return invoke(input_tensor, shape_vector, std::nullopt);
}

} // namespace ttnn::operations::data_movement
Loading
Loading