diff --git a/tests/ttnn/unit_tests/operations/test_creation.py b/tests/ttnn/unit_tests/operations/test_creation.py
index 1aff7f0e63e..eba32c081e7 100644
--- a/tests/ttnn/unit_tests/operations/test_creation.py
+++ b/tests/ttnn/unit_tests/operations/test_creation.py
@@ -138,6 +138,39 @@ def test_full(device, input_shape, fill_value):
     assert torch.allclose(torch_tensor, tensor)
 
 
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        [32, 32],
+        [5, 96, 64],
+    ],
+)
+@pytest.mark.parametrize(
+    "fill_value",
+    [-5.25, 0, 2.5, 9],
+)
+@pytest.mark.parametrize(
+    "layout",
+    [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE],
+)
+def test_full_with_opt_tensor(device, input_shape, layout, fill_value):
+    torch_tensor = torch.full(input_shape, dtype=torch.bfloat16, fill_value=fill_value)
+    opt_tensor = torch.ones(input_shape, dtype=torch.bfloat16)
+    opt_tensor = ttnn.from_torch(
+        opt_tensor, ttnn.bfloat16, layout=layout, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
+    )
+
+    cq_id = 0
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.full(input_shape, device=device, fill_value=fill_value, optional_tensor=opt_tensor, queue_id=cq_id)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+    assert ttnn.is_tensor_storage_on_device(opt_tensor)
+    opt_tensor = ttnn.to_torch(opt_tensor)
+
+    assert_with_pcc(torch_tensor, opt_tensor, 0.9999)
+    assert torch.allclose(torch_tensor, opt_tensor)
+
+
 @pytest.mark.parametrize(
     "start",
     [4, 8, 16, 32],
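The new test doubles as a usage example for the in-place path: `ttnn.full` writes into a preallocated device tensor instead of allocating a new output, which the `get_buffer_pages` comparison verifies. A minimal standalone sketch of the same flow (assuming a single device with id 0; shapes and values are illustrative):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # Preallocate the destination tensor; ttnn.full fills it in place on device.
    out = ttnn.from_torch(
        torch.ones([32, 32], dtype=torch.bfloat16),
        ttnn.bfloat16,
        layout=ttnn.Layout.TILE,
        device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    # queue_id selects the command queue; 0 is the default queue.
    ttnn.full([32, 32], fill_value=2.5, device=device, optional_tensor=out, queue_id=0)

    assert torch.all(ttnn.to_torch(out) == 2.5)
    ttnn.close_device(device)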
diff --git a/ttnn/cpp/pybind11/operations/creation.hpp b/ttnn/cpp/pybind11/operations/creation.hpp
index 28679099f44..6feabd48a74 100644
--- a/ttnn/cpp/pybind11/operations/creation.hpp
+++ b/ttnn/cpp/pybind11/operations/creation.hpp
@@ -35,15 +35,19 @@ void bind_full_operation(py::module& module, const creation_operation_t& operation) {
                const std::optional<DataType>& dtype,
                const std::optional<Layout>& layout,
                const std::optional<std::reference_wrapper<Device>>& device,
-               const std::optional<MemoryConfig>& memory_config) -> ttnn::Tensor {
-                return self(ttnn::Shape{tt::tt_metal::Shape{shape}}, fill_value, dtype, layout, device, memory_config);
+               const std::optional<MemoryConfig>& memory_config,
+               std::optional<ttnn::Tensor>& optional_output_tensor,
+               uint8_t queue_id) -> ttnn::Tensor {
+                return self(queue_id, ttnn::Shape{tt::tt_metal::Shape{shape}}, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
            },
            py::arg("shape"),
            py::arg("fill_value"),
            py::arg("dtype") = std::nullopt,
            py::arg("layout") = std::nullopt,
            py::arg("device") = std::nullopt,
-            py::arg("memory_config") = std::nullopt},
+            py::arg("memory_config") = std::nullopt,
+            py::arg("optional_tensor") = std::nullopt,
+            py::arg("queue_id") = ttnn::DefaultQueueId},
        ttnn::pybind_overload_t{
            [](const creation_operation_t& self,
               const std::vector<uint32_t>& shape,
@@ -51,15 +55,19 @@ void bind_full_operation(py::module& module, const creation_operation_t& operation) {
               const std::optional<DataType>& dtype,
               const std::optional<Layout>& layout,
               const std::optional<std::reference_wrapper<Device>>& device,
-               const std::optional<MemoryConfig>& memory_config) -> ttnn::Tensor {
-                return self(ttnn::Shape{tt::tt_metal::Shape{shape}}, fill_value, dtype, layout, device, memory_config);
+               const std::optional<MemoryConfig>& memory_config,
+               std::optional<ttnn::Tensor>& optional_output_tensor,
+               uint8_t queue_id) -> ttnn::Tensor {
+                return self(queue_id, ttnn::Shape{tt::tt_metal::Shape{shape}}, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
            },
            py::arg("shape"),
            py::arg("fill_value"),
            py::arg("dtype") = std::nullopt,
            py::arg("layout") = std::nullopt,
            py::arg("device") = std::nullopt,
-            py::arg("memory_config") = std::nullopt});
+            py::arg("memory_config") = std::nullopt,
+            py::arg("optional_tensor") = std::nullopt,
+            py::arg("queue_id") = ttnn::DefaultQueueId});
 }
 
 template <typename creation_operation_t>
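Both new keyword arguments carry defaults (`std::nullopt` and `ttnn::DefaultQueueId`), so the binding stays backwards compatible: call sites that never mention `optional_tensor` or `queue_id` should continue to allocate and return a fresh tensor, as in this sketch (assuming an open `device` handle as above):

    # Pre-existing call form, unchanged: allocates and returns a new tensor.
    t = ttnn.full([32, 32], fill_value=1.0, dtype=ttnn.bfloat16, device=device)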
diff --git a/ttnn/cpp/ttnn/deprecated/tt_numpy/functions.hpp b/ttnn/cpp/ttnn/deprecated/tt_numpy/functions.hpp
index a95a2dac252..3ba2b7f6bf2 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_numpy/functions.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_numpy/functions.hpp
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include "ttnn/cpp/ttnn/common/constants.hpp"
 
 namespace tt {
 
@@ -49,12 +50,14 @@ constexpr static DataType get_data_type() {
 
 template <typename T>
 static Tensor full(
+    uint8_t queue_id,
     const Shape& shape,
     T value,
     const Layout layout = Layout::ROW_MAJOR,
     Device* device = nullptr,
     const MemoryConfig& output_mem_config = MemoryConfig{
-        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
+        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
+    std::optional<Tensor> optional_output_tensor = std::nullopt) {
     if (layout == Layout::TILE) {
         if (shape.rank() < 2) {
             TT_THROW("TILE layout requires rank >= 2");
@@ -69,48 +72,83 @@ static Tensor full(
             tt::constants::TILE_HEIGHT);
     }
 
-    constexpr DataType data_type = detail::get_data_type<T>();
-    auto owned_buffer = tt_metal::owned_buffer::create<T>(tt_metal::compute_volume(shape));
-    std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
-    auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
-    if (device != nullptr) {
-        output = output.to(device, output_mem_config);
-    }
-    return output;
+    constexpr DataType data_type = detail::get_data_type<T>();
+    auto owned_buffer = tt_metal::owned_buffer::create<T>(tt_metal::compute_volume(shape));
+    std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
+
+    if(!optional_output_tensor.has_value()){
+        auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
+        if (device != nullptr) {
+            output = output.to(device, output_mem_config);
+        }
+        return output;
+    }
+    else {
+        auto device_buffer = std::get<DeviceStorage>(optional_output_tensor.value().tensor_attributes->storage).get_buffer();
+        bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);
+
+        if (using_fast_dispatch && device != nullptr) {
+            auto& cmd_queue = device->command_queue(queue_id);
+            if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) {
+                tt::tt_metal::EnqueueWriteBuffer(cmd_queue, device_buffer, owned_buffer.get_ptr(), false);
+            } else {
+                tt::tt_metal::EnqueueWriteBuffer(cmd_queue, device_buffer, owned_buffer.data(), false);
+            }
+        } else {
+            auto uint32_data = tt::tt_metal::tensor_impl::pack_vec_into_uint32_vec(owned_buffer);
+            tt::tt_metal::detail::WriteToBuffer(*device_buffer, uint32_data);
+        }
+
+        return optional_output_tensor.value();
+    }
 }
 
 }  // namespace detail
 
 template <typename T>
-static Tensor full(
+static Tensor full_impl(
+    uint8_t queue_id,
     const Shape& shape,
     const T value,
     const DataType data_type,
     const Layout layout = Layout::ROW_MAJOR,
     Device* device = nullptr,
     const MemoryConfig& output_mem_config = MemoryConfig{
-        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
+        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
+    std::optional<Tensor> optional_output_tensor = std::nullopt) {
     switch (data_type) {
         case DataType::UINT8: {
-            return detail::full<uint8_t>(shape, uint8_t(value), layout, device, output_mem_config);
+            return detail::full<uint8_t>(queue_id, shape, uint8_t(value), layout, device, output_mem_config, optional_output_tensor);
         }
         case DataType::UINT16: {
-            return detail::full<uint16_t>(shape, uint16_t(value), layout, device, output_mem_config);
+            return detail::full<uint16_t>(queue_id, shape, uint16_t(value), layout, device, output_mem_config, optional_output_tensor);
         }
         case DataType::UINT32: {
-            return detail::full<uint32_t>(shape, uint32_t(value), layout, device, output_mem_config);
+            return detail::full<uint32_t>(queue_id, shape, uint32_t(value), layout, device, output_mem_config, optional_output_tensor);
         }
         case DataType::FLOAT32: {
-            return detail::full<float>(shape, float(value), layout, device, output_mem_config);
+            return detail::full<float>(queue_id, shape, float(value), layout, device, output_mem_config, optional_output_tensor);
         }
         case DataType::BFLOAT16: {
             return detail::full<bfloat16>(
-                shape, bfloat16(static_cast<float>(value)), layout, device, output_mem_config);
+                queue_id, shape, bfloat16(static_cast<float>(value)), layout, device, output_mem_config, optional_output_tensor);
         }
         default: TT_THROW("Unsupported DataType!");
     }
 }
 
+template <typename T>
+static Tensor full(
+    const Shape& shape,
+    const T value,
+    const DataType data_type,
+    const Layout layout = Layout::ROW_MAJOR,
+    Device* device = nullptr,
+    const MemoryConfig& output_mem_config = MemoryConfig{
+        .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
+    return full_impl(ttnn::DefaultQueueId, shape, value, data_type, layout, device, output_mem_config, std::nullopt);
+}
+
 static Tensor zeros(
     const Shape& shape,
     const DataType data_type = DataType::BFLOAT16,
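When an output tensor is supplied, `detail::full` still fills a host-side owned buffer, but then writes it into the tensor's existing device buffer: `EnqueueWriteBuffer` on the requested command queue under fast dispatch (handing over the buffer's shared pointer in async mode, a raw pointer otherwise), or a blocking `WriteToBuffer` under slow dispatch. Which path runs is decided by the `TT_METAL_SLOW_DISPATCH_MODE` environment variable, as in this sketch (assumption: the variable must be set before the device is opened):

    import os

    # Force the slow-dispatch WriteToBuffer path; leaving the variable unset
    # keeps the default fast-dispatch EnqueueWriteBuffer path.
    os.environ["TT_METAL_SLOW_DISPATCH_MODE"] = "1"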
diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp
index 51873b35080..25f98c42c67 100644
--- a/ttnn/cpp/ttnn/operations/creation.hpp
+++ b/ttnn/cpp/ttnn/operations/creation.hpp
@@ -12,6 +12,7 @@
 #include "ttnn/core.hpp"
 #include "ttnn/decorators.hpp"
 #include "ttnn/types.hpp"
+#include "ttnn/common/constants.hpp"
 
 namespace ttnn {
 namespace operations {
@@ -55,21 +56,42 @@ Tensor create_scalar(T scalar, DataType data_type, Layout layout, Device* device) {
 }
 
 template <typename T>
-inline ttnn::Tensor full(
+inline ttnn::Tensor full_impl(
+    uint8_t queue_id,
     const ttnn::Shape& shape,
     const T fill_value,
     const std::optional<DataType>& dtype = std::nullopt,
     const std::optional<Layout>& layout = std::nullopt,
     const std::optional<std::reference_wrapper<Device>>& device_arg = std::nullopt,
-    const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-    Device* device = device_arg.has_value() ? &(device_arg.value().get()) : nullptr;
-    return tt::numpy::full(
-        shape.value,
+    const std::optional<MemoryConfig>& memory_config = std::nullopt,
+    std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
+    Device* device = optional_output_tensor.has_value() ? optional_output_tensor.value().device() : device_arg.has_value() ? &(device_arg.value().get()) : nullptr;
+    Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() : layout.value_or(ttnn::ROW_MAJOR_LAYOUT);
+    DataType dtype_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_dtype() : dtype.value_or(ttnn::bfloat16);
+    tt::tt_metal::Shape shape_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_legacy_shape() : shape.value;
+    MemoryConfig mem_cfg = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
+    return tt::numpy::full_impl(
+        queue_id,
+        shape_value,
         fill_value,
-        dtype.value_or(ttnn::bfloat16),
-        layout.value_or(ttnn::ROW_MAJOR_LAYOUT),
+        dtype_value,
+        layout_value,
         device,
-        memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
+        mem_cfg,
+        optional_output_tensor);
+}
+
+template <typename T>
+inline ttnn::Tensor full(
+    const ttnn::Shape& shape,
+    const T fill_value,
+    const std::optional<DataType>& dtype = std::nullopt,
+    const std::optional<Layout>& layout = std::nullopt,
+    const std::optional<std::reference_wrapper<Device>>& device_arg = std::nullopt,
+    const std::optional<MemoryConfig>& memory_config = std::nullopt,
+    std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt,
+    uint8_t queue_id = ttnn::DefaultQueueId) {
+    return full_impl(queue_id, shape, fill_value, dtype, layout, device_arg, memory_config, optional_output_tensor);
 }
 
 namespace detail {
@@ -159,23 +181,49 @@ inline constexpr EmptyLike empty_like{};
 
 struct Full {
     static ttnn::Tensor invoke(
+        uint8_t queue_id,
         const ttnn::Shape& shape,
         const float fill_value,
         const std::optional<DataType>& dtype = std::nullopt,
         const std::optional<Layout>& layout = std::nullopt,
         const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
-        const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-        return full(shape, fill_value, dtype, layout, device, memory_config);
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
+        return full_impl(queue_id, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
     }
 
     static ttnn::Tensor invoke(
+        uint8_t queue_id,
         const ttnn::Shape& shape,
         const int fill_value,
         const std::optional<DataType>& dtype = std::nullopt,
         const std::optional<Layout>& layout = std::nullopt,
         const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
-        const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-        return full(shape, fill_value, dtype, layout, device, memory_config);
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
+        return full_impl(queue_id, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
+    }
+
+    static ttnn::Tensor invoke(
+        const ttnn::Shape& shape,
+        const float fill_value,
+        const std::optional<DataType>& dtype = std::nullopt,
+        const std::optional<Layout>& layout = std::nullopt,
+        const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
+        return full_impl(ttnn::DefaultQueueId, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
+    }
+
+    static ttnn::Tensor invoke(
+        const ttnn::Shape& shape,
+        const int fill_value,
+        const std::optional<DataType>& dtype = std::nullopt,
+        const std::optional<Layout>& layout = std::nullopt,
+        const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt,
+        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
+        return full_impl(ttnn::DefaultQueueId, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
     }
 };
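One behavioural point worth calling out in `full_impl`: when `optional_output_tensor` is provided, the device, layout, dtype, shape, and memory config are all derived from that tensor, and the corresponding explicit arguments are ignored. A hypothetical sketch of that precedence from Python (assuming an open `device` handle as in the first sketch):

    out = ttnn.from_torch(
        torch.ones([32, 32], dtype=torch.bfloat16),
        ttnn.bfloat16,
        layout=ttnn.Layout.TILE,
        device=device,
    )
    # layout=ROW_MAJOR is overridden here: the fill keeps the TILE layout,
    # bfloat16 dtype, and memory config of `out`.
    ttnn.full([32, 32], fill_value=0.0, layout=ttnn.Layout.ROW_MAJOR, optional_tensor=out)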