tenstorrent#12065: Add optional tensor qid for tt::numpy::full and ttnn.full (tenstorrent#12289)
KalaivaniMCW authored Sep 9, 2024
1 parent ba4df4b commit 3945004
Showing 4 changed files with 161 additions and 34 deletions.
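In short: `ttnn.full` (and the underlying `tt::numpy::full_impl`) can now write the fill value into a pre-allocated device tensor on a chosen command queue instead of allocating a fresh output. A minimal usage sketch, assuming a single open device (`open_device`/`close_device` are standard ttnn calls, not part of this commit; shape and fill value are illustrative):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Pre-allocate the destination once, here in L1, from a host tensor.
out = ttnn.from_torch(
    torch.ones([32, 32], dtype=torch.bfloat16),
    ttnn.bfloat16,
    layout=ttnn.Layout.TILE,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)

# Fill it in place on command queue 0; no new device buffer is allocated.
ttnn.full([32, 32], fill_value=2.5, device=device, optional_tensor=out, queue_id=0)

ttnn.close_device(device)
```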
33 changes: 33 additions & 0 deletions tests/ttnn/unit_tests/operations/test_creation.py
@@ -138,6 +138,39 @@ def test_full(device, input_shape, fill_value):
    assert torch.allclose(torch_tensor, tensor)


@pytest.mark.parametrize(
    "input_shape",
    [
        [32, 32],
        [5, 96, 64],
    ],
)
@pytest.mark.parametrize(
    "fill_value",
    [-5.25, 0, 2.5, 9],
)
@pytest.mark.parametrize(
    "layout",
    [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE],
)
def test_full_with_opt_tensor(device, input_shape, layout, fill_value):
    torch_tensor = torch.full(input_shape, dtype=torch.bfloat16, fill_value=fill_value)
    opt_tensor = torch.ones(input_shape, dtype=torch.bfloat16)
    opt_tensor = ttnn.from_torch(
        opt_tensor, ttnn.bfloat16, layout=layout, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
    )

    cq_id = 0
    pages_before = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.full(input_shape, device=device, fill_value=fill_value, optional_tensor=opt_tensor, queue_id=cq_id)
    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
    assert ttnn.is_tensor_storage_on_device(opt_tensor)
    opt_tensor = ttnn.to_torch(opt_tensor)

    assert_with_pcc(torch_tensor, opt_tensor, 0.9999)
    assert torch.allclose(torch_tensor, opt_tensor)


@pytest.mark.parametrize(
    "start",
    [4, 8, 16, 32],
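The buffer-page assertion above is the in-place guarantee: because the destination buffer already exists, a successful `ttnn.full` call must leave the device's buffer table unchanged. A hedged sketch of the pattern this enables, reusing one pre-allocated tensor across several fills (`fill_many` is a hypothetical helper):

```python
import torch
import ttnn

def fill_many(device, shape, values):
    # Allocate the destination buffer once...
    out = ttnn.from_torch(
        torch.zeros(shape, dtype=torch.bfloat16),
        ttnn.bfloat16,
        layout=ttnn.Layout.TILE,
        device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    results = []
    for v in values:
        # ...and overwrite it in place for each fill value.
        ttnn.full(shape, fill_value=v, device=device, optional_tensor=out, queue_id=0)
        results.append(ttnn.to_torch(out))
    return results
```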
20 changes: 14 additions & 6 deletions ttnn/cpp/pybind11/operations/creation.hpp
@@ -35,31 +35,39 @@ void bind_full_operation(py::module& module, const creation_operation_t& operati
const std::optional<DataType>& dtype,
const std::optional<Layout>& layout,
const std::optional<std::reference_wrapper<Device>>& device,
const std::optional<MemoryConfig>& memory_config) -> ttnn::Tensor {
return self(ttnn::Shape{tt::tt_metal::Shape{shape}}, fill_value, dtype, layout, device, memory_config);
const std::optional<MemoryConfig>& memory_config,
std::optional<ttnn::Tensor> &optional_output_tensor,
uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, ttnn::Shape{tt::tt_metal::Shape{shape}}, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
},
py::arg("shape"),
py::arg("fill_value"),
py::arg("dtype") = std::nullopt,
py::arg("layout") = std::nullopt,
py::arg("device") = std::nullopt,
py::arg("memory_config") = std::nullopt},
py::arg("memory_config") = std::nullopt,
py::arg("optional_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId},
ttnn::pybind_overload_t{
[](const creation_operation_t& self,
const std::vector<uint32_t>& shape,
const int fill_value,
const std::optional<DataType>& dtype,
const std::optional<Layout>& layout,
const std::optional<std::reference_wrapper<Device>>& device,
const std::optional<MemoryConfig>& memory_config) -> ttnn::Tensor {
return self(ttnn::Shape{tt::tt_metal::Shape{shape}}, fill_value, dtype, layout, device, memory_config);
const std::optional<MemoryConfig>& memory_config,
std::optional<ttnn::Tensor> &optional_output_tensor,
uint8_t queue_id) -> ttnn::Tensor {
return self(queue_id, ttnn::Shape{tt::tt_metal::Shape{shape}}, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
},
py::arg("shape"),
py::arg("fill_value"),
py::arg("dtype") = std::nullopt,
py::arg("layout") = std::nullopt,
py::arg("device") = std::nullopt,
py::arg("memory_config") = std::nullopt});
py::arg("memory_config") = std::nullopt,
py::arg("optional_tensor") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId});
}

template <typename creation_operation_t>
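Both pybind overloads (float and int `fill_value`) gain the same two trailing keyword arguments with defaults, so the new parameters are opt-in from Python and omitting them reproduces the old allocate-and-return behaviour. A sketch of the two calling styles (assumes `device` is an open ttnn device):

```python
import ttnn

# Old style: allocate and return a fresh tensor. queue_id defaults to
# ttnn.DefaultQueueId and optional_tensor to None.
t = ttnn.full([32, 32], fill_value=9, device=device, layout=ttnn.Layout.TILE)

# New style: write into the pre-allocated tensor on an explicit queue.
ttnn.full([32, 32], fill_value=9, device=device, optional_tensor=t, queue_id=0)
```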
70 changes: 54 additions & 16 deletions ttnn/cpp/ttnn/deprecated/tt_numpy/functions.hpp
@@ -13,6 +13,7 @@
#include <ttnn/tensor/tensor_utils.hpp>
#include <ttnn/tensor/types.hpp>
#include <ttnn/tensor/tensor_impl.hpp>
#include "ttnn/cpp/ttnn/common/constants.hpp"

namespace tt {

@@ -49,12 +50,14 @@ constexpr static DataType get_data_type() {

template <typename T>
static Tensor full(
uint8_t queue_id,
const Shape& shape,
T value,
const Layout layout = Layout::ROW_MAJOR,
Device* device = nullptr,
const MemoryConfig& output_mem_config = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
std::optional<Tensor> optional_output_tensor = std::nullopt) {
if (layout == Layout::TILE) {
if (shape.rank() < 2) {
TT_THROW("TILE layout requires rank >= 2");
@@ -69,48 +72,83 @@
tt::constants::TILE_HEIGHT);
}

constexpr DataType data_type = detail::get_data_type<T>();
auto owned_buffer = tt_metal::owned_buffer::create<T>(tt_metal::compute_volume(shape));
std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
if (device != nullptr) {
output = output.to(device, output_mem_config);
}
return output;
constexpr DataType data_type = detail::get_data_type<T>();
auto owned_buffer = tt_metal::owned_buffer::create<T>(tt_metal::compute_volume(shape));
std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);

if(!optional_output_tensor.has_value()){
auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
if (device != nullptr) {
output = output.to(device, output_mem_config);
}
return output;
}
else {
auto device_buffer = std::get<DeviceStorage>(optional_output_tensor.value().tensor_attributes->storage).get_buffer();
bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);

if (using_fast_dispatch && device != nullptr) {
auto& cmd_queue = device->command_queue(queue_id);
if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) {
tt::tt_metal::EnqueueWriteBuffer(cmd_queue, device_buffer, owned_buffer.get_ptr(), false);
} else {
tt::tt_metal::EnqueueWriteBuffer(cmd_queue, device_buffer, owned_buffer.data(), false);
}
} else {
auto uint32_data = tt::tt_metal::tensor_impl::pack_vec_into_uint32_vec<T>(owned_buffer);
tt::tt_metal::detail::WriteToBuffer(*device_buffer, uint32_data);
}

return optional_output_tensor.value();
}
}

} // namespace detail

template <typename T>
static Tensor full(
static Tensor full_impl(
uint8_t queue_id,
const Shape& shape,
const T value,
const DataType data_type,
const Layout layout = Layout::ROW_MAJOR,
Device* device = nullptr,
const MemoryConfig& output_mem_config = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
std::optional<Tensor> optional_output_tensor = std::nullopt) {
switch (data_type) {
case DataType::UINT8: {
return detail::full<uint8_t>(shape, uint8_t(value), layout, device, output_mem_config);
return detail::full<uint8_t>(queue_id, shape, uint8_t(value), layout, device, output_mem_config, optional_output_tensor);
}
case DataType::UINT16: {
return detail::full<uint16_t>(shape, uint16_t(value), layout, device, output_mem_config);
return detail::full<uint16_t>(queue_id, shape, uint16_t(value), layout, device, output_mem_config, optional_output_tensor);
}
case DataType::UINT32: {
return detail::full<uint32_t>(shape, uint32_t(value), layout, device, output_mem_config);
return detail::full<uint32_t>(queue_id, shape, uint32_t(value), layout, device, output_mem_config, optional_output_tensor);
}
case DataType::FLOAT32: {
return detail::full<float>(shape, float(value), layout, device, output_mem_config);
return detail::full<float>(queue_id, shape, float(value), layout, device, output_mem_config, optional_output_tensor);
}
case DataType::BFLOAT16: {
return detail::full<bfloat16>(
shape, bfloat16(static_cast<float>(value)), layout, device, output_mem_config);
queue_id, shape, bfloat16(static_cast<float>(value)), layout, device, output_mem_config, optional_output_tensor);
}
default: TT_THROW("Unsupported DataType!");
}
}

template <typename T>
static Tensor full(
const Shape& shape,
const T value,
const DataType data_type,
const Layout layout = Layout::ROW_MAJOR,
Device* device = nullptr,
const MemoryConfig& output_mem_config = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
return full_impl(ttnn::DefaultQueueId, shape, value, data_type, layout, device, output_mem_config, std::nullopt);
}

static Tensor zeros(
const Shape& shape,
const DataType data_type = DataType::BFLOAT16,
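The branch worth noting in `detail::full` is the write path when an output tensor is supplied: under fast dispatch the host-side buffer is enqueued on the requested command queue (non-blocking, via `get_ptr()` when the command queue is in async mode), while under slow dispatch (`TT_METAL_SLOW_DISPATCH_MODE` set) the data is packed into uint32 words and written synchronously. A hypothetical Python mirror of that decision, for illustration only:

```python
import os

def select_write_path(device_present: bool, queue_id: int) -> str:
    # Mirrors the branch in detail::full: fast dispatch is assumed unless
    # TT_METAL_SLOW_DISPATCH_MODE is set in the environment.
    using_fast_dispatch = os.environ.get("TT_METAL_SLOW_DISPATCH_MODE") is None
    if using_fast_dispatch and device_present:
        return f"EnqueueWriteBuffer on command queue {queue_id} (non-blocking)"
    return "pack_vec_into_uint32_vec + WriteToBuffer (synchronous)"
```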
72 changes: 60 additions & 12 deletions ttnn/cpp/ttnn/operations/creation.hpp
@@ -12,6 +12,7 @@
#include "ttnn/core.hpp"
#include "ttnn/decorators.hpp"
#include "ttnn/types.hpp"
#include "ttnn/common/constants.hpp"

namespace ttnn {
namespace operations {
@@ -55,21 +56,42 @@ Tensor create_scalar(T scalar, DataType data_type, Layout layout, Device* device
}

template <typename T>
inline ttnn::Tensor full(
inline ttnn::Tensor full_impl(
uint8_t queue_id,
const ttnn::Shape& shape,
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<std::reference_wrapper<Device>>& device_arg = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
Device* device = device_arg.has_value() ? &(device_arg.value().get()) : nullptr;
return tt::numpy::full(
shape.value,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
Device* device = optional_output_tensor.has_value() ? optional_output_tensor.value().device() : device_arg.has_value() ? &(device_arg.value().get()) : nullptr;
Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() : layout.value_or(ttnn::ROW_MAJOR_LAYOUT);
DataType dtype_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_dtype() : dtype.value_or(ttnn::bfloat16);
tt::tt_metal::Shape shape_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_legacy_shape() : shape.value;
MemoryConfig mem_cfg = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
return tt::numpy::full_impl(
queue_id,
shape_value,
fill_value,
dtype.value_or(ttnn::bfloat16),
layout.value_or(ttnn::ROW_MAJOR_LAYOUT),
dtype_value,
layout_value,
device,
memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG));
mem_cfg,
optional_output_tensor);
}

template <typename T>
inline ttnn::Tensor full(
const ttnn::Shape& shape,
const T fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<std::reference_wrapper<Device>>& device_arg = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt,
uint8_t queue_id = ttnn::DefaultQueueId) {
return full_impl(queue_id, shape, fill_value, dtype, layout, device_arg, memory_config, optional_output_tensor);
}

namespace detail {
@@ -159,23 +181,49 @@ inline constexpr EmptyLike empty_like{};

struct Full {
static ttnn::Tensor invoke(
uint8_t queue_id,
const ttnn::Shape& shape,
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return full(shape, fill_value, dtype, layout, device, memory_config);
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(queue_id, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
}

static ttnn::Tensor invoke(
uint8_t queue_id,
const ttnn::Shape& shape,
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return full(shape, fill_value, dtype, layout, device, memory_config);
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(queue_id, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
}

static ttnn::Tensor invoke(
const ttnn::Shape& shape,
const float fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(ttnn::DefaultQueueId, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
}

static ttnn::Tensor invoke(
const ttnn::Shape& shape,
const int fill_value,
const std::optional<DataType>& dtype = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
return full_impl(ttnn::DefaultQueueId, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor);
}
};
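Note the precedence rule in `full_impl`: when `optional_output_tensor` is supplied, its device, shape, dtype, layout, and memory config override the explicit arguments, so a caller cannot request a fill that mismatches the destination buffer. A hedged illustration (assumes `device` is an open ttnn device and that ttnn tensors expose `.dtype`/`.layout` attributes):

```python
import torch
import ttnn

out = ttnn.from_torch(
    torch.ones([32, 32], dtype=torch.bfloat16),
    ttnn.bfloat16,
    layout=ttnn.Layout.TILE,
    device=device,
)

# The dtype/layout arguments are ignored in favour of the destination's own.
ttnn.full([32, 32], fill_value=1.5, dtype=ttnn.float32, layout=ttnn.Layout.ROW_MAJOR, optional_tensor=out)
assert out.dtype == ttnn.bfloat16
assert out.layout == ttnn.Layout.TILE
```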
