Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tt-train] Enable tensor parallel for MNIST #17506

Merged
merged 15 commits into from
Feb 5, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Clean up tests. Add ColumnParallelLinear
  • Loading branch information
rfurko-tt committed Feb 4, 2025
commit ea21f7ac2d22a690cdca8e84c564ddc1e9763cf7
178 changes: 169 additions & 9 deletions tt-train/tests/modules/distributed/linear_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "autograd/auto_context.hpp"
#include "core/distributed_mapping.hpp"
#include "core/tt_tensor_utils.hpp"
#include "init/cpu_initializers.hpp"

namespace {

Expand Down Expand Up @@ -156,10 +155,6 @@ TEST_F(N300TensorParallelLinearTest, RowParallelLinearHasBiasInputParallel) {
// (1, 1, out_features, in_features)
auto weight_xtensor = ttml::core::to_xtensor<float>(weight->get_value(), concat_composer);
auto bias_xtensor = ttml::core::to_xtensor<float>(bias->get_value(), identity_composer);

auto weight_xtensor_shape = weight_xtensor[0].shape();
auto test_data_shape = test_data.shape();

auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2}));
if (has_bias) {
expected_output += bias_xtensor[0];
Expand Down Expand Up @@ -197,12 +192,177 @@ TEST_F(N300TensorParallelLinearTest, RowParallelLinearNoBiasInputParallel) {
ttml::core::MeshToXTensorVariant<float> concat_composer = ttml::core::ConcatMeshToXTensor<float>(mesh_shape, 3U);
// (1, 1, out_features, in_features)
auto weight_xtensor = ttml::core::to_xtensor<float>(weight->get_value(), concat_composer);

auto weight_xtensor_shape = weight_xtensor[0].shape();
auto test_data_shape = test_data.shape();

auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2}));

EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2));
EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2));
};

// ColumnParallelLinear with bias, all-gather enabled.
//
// The same random input is replicated to the mesh and pushed through the layer. The test then
//   1. checks that the per-device outputs at index 0 and index 1 agree with each other, and
//   2. checks each of them against a host-side reference: input x weight^T (+ bias), where the
//      weight is read back with a concat composer over dim 2 and the bias with one over dim 3.
TEST_F(N300TensorParallelLinearTest, ColumnParallelLinearHasBiasAllGather) {
    const uint32_t in_features = 64U;
    const uint32_t out_features = 64U;
    const bool has_bias = true;
    const bool use_all_gather = true;

    auto linear = ttml::modules::distributed::ColumnParallelLinear(in_features, out_features, has_bias, use_all_gather);
    auto params = linear.parameters();
    // One weight tensor, plus one bias tensor when bias is enabled.
    EXPECT_EQ(params.size(), 1UL + static_cast<size_t>(has_bias));

    auto weight_param = get_parameter(params, "weight");
    auto bias_param = get_parameter(params, "bias");

    auto* mesh_device = &ttml::autograd::ctx().get_device();
    auto mesh_dims = mesh_device->shape();

    // Host-side input of shape (1, 1, 1, in_features), replicated to the mesh.
    xt::xarray<float> host_input = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features});
    ttml::core::XTensorToMeshVariant<float> to_mesh_replicated = ttml::core::ReplicateXTensorToMesh<float>(mesh_dims);
    auto device_input = ttml::core::from_xtensor<float, DataType::BFLOAT16>(host_input, mesh_device, to_mesh_replicated);
    auto input_tensor = ttml::autograd::create_tensor(device_input);
    auto result = linear(input_tensor);

    // Read the output back as one xtensor per device.
    ttml::core::MeshToXTensorVariant<float> per_device_composer = ttml::core::VectorMeshToXTensor<float>(mesh_dims);
    auto device_outputs = ttml::core::to_xtensor<float>(result->get_value(), per_device_composer);
    EXPECT_TRUE(xt::allclose(device_outputs[0], device_outputs[1], /* rtol */ 1e-3, /* atol */ 1e-2));

    ttml::core::MeshToXTensorVariant<float> concat_over_dim2 = ttml::core::ConcatMeshToXTensor<float>(mesh_dims, 2U);
    ttml::core::MeshToXTensorVariant<float> concat_over_dim3 = ttml::core::ConcatMeshToXTensor<float>(mesh_dims, 3U);
    // Weight read back as (1, 1, out_features, in_features).
    auto host_weight = ttml::core::to_xtensor<float>(weight_param->get_value(), concat_over_dim2);
    auto host_bias = ttml::core::to_xtensor<float>(bias_param->get_value(), concat_over_dim3);

    // Reference result: input times the transposed weight, plus bias when present.
    auto reference = xt::linalg::dot(host_input, xt::transpose(host_weight[0], {0, 1, 3, 2}));
    if (has_bias) {
        reference += host_bias[0];
    }

    EXPECT_TRUE(xt::allclose(reference, device_outputs[0], /* rtol */ 1e-2, /* atol */ 1e-2));
    EXPECT_TRUE(xt::allclose(reference, device_outputs[1], /* rtol */ 1e-2, /* atol */ 1e-2));
};

// ColumnParallelLinear without bias, all-gather enabled.
//
// A random input is replicated to the mesh and run through the layer. The per-device outputs at
// index 0 and index 1 must match each other, and each must match the host-side reference
// input x weight^T, where the weight is read back with a concat composer over dim 2.
TEST_F(N300TensorParallelLinearTest, ColumnParallelLinearNoBiasAllGather) {
    uint32_t in_features = 64U;
    uint32_t out_features = 64U;
    bool has_bias = false;
    bool use_all_gather = true;

    auto layer = ttml::modules::distributed::ColumnParallelLinear(in_features, out_features, has_bias, use_all_gather);
    auto parameters = layer.parameters();
    // Only the weight is expected when bias is disabled.
    EXPECT_EQ(parameters.size(), 1UL + static_cast<size_t>(has_bias));

    auto weight = get_parameter(parameters, "weight");

    auto* device = &ttml::autograd::ctx().get_device();
    auto mesh_shape = device->shape();

    // Host-side input of shape (1, 1, 1, in_features), replicated to the mesh.
    xt::xarray<float> test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features});
    ttml::core::XTensorToMeshVariant<float> replicate_composer = ttml::core::ReplicateXTensorToMesh<float>(mesh_shape);
    auto tt_tensor = ttml::core::from_xtensor<float, DataType::BFLOAT16>(test_data, device, replicate_composer);
    auto tensor = ttml::autograd::create_tensor(tt_tensor);
    auto output = layer(tensor);

    // Read the output back as one xtensor per device; both must agree.
    ttml::core::MeshToXTensorVariant<float> identity_composer = ttml::core::VectorMeshToXTensor<float>(mesh_shape);
    auto output_xtensor = ttml::core::to_xtensor<float>(output->get_value(), identity_composer);
    EXPECT_TRUE(xt::allclose(output_xtensor[0], output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2));

    // No bias in this test, so only the dim-2 composer for the weight is needed.
    ttml::core::MeshToXTensorVariant<float> concat_composer_2 = ttml::core::ConcatMeshToXTensor<float>(mesh_shape, 2U);
    // (1, 1, out_features, in_features)
    auto weight_xtensor = ttml::core::to_xtensor<float>(weight->get_value(), concat_composer_2);
    auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2}));

    EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[0], /* rtol */ 1e-2, /* atol */ 1e-2));
    EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[1], /* rtol */ 1e-2, /* atol */ 1e-2));
};

// ColumnParallelLinear with bias, all-gather disabled.
//
// A random input is replicated to the mesh and run through the layer. The host-side reference
// input x weight^T (+ bias) is reshaped to (1, 1, 1, out_features); the per-device output at
// index 0 is compared with the first half of its last dimension and the output at index 1 with
// the second half.
TEST_F(N300TensorParallelLinearTest, ColumnParallelLinearHasBiasNoAllGather) {
    const uint32_t in_features = 64U;
    const uint32_t out_features = 64U;
    const bool has_bias = true;
    const bool use_all_gather = false;

    auto linear = ttml::modules::distributed::ColumnParallelLinear(in_features, out_features, has_bias, use_all_gather);
    auto params = linear.parameters();
    // One weight tensor, plus one bias tensor when bias is enabled.
    EXPECT_EQ(params.size(), 1UL + static_cast<size_t>(has_bias));

    auto weight_param = get_parameter(params, "weight");
    auto bias_param = get_parameter(params, "bias");

    auto* mesh_device = &ttml::autograd::ctx().get_device();
    auto mesh_dims = mesh_device->shape();

    // Host-side input of shape (1, 1, 1, in_features), replicated to the mesh.
    xt::xarray<float> host_input = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features});
    ttml::core::XTensorToMeshVariant<float> to_mesh_replicated = ttml::core::ReplicateXTensorToMesh<float>(mesh_dims);
    auto device_input = ttml::core::from_xtensor<float, DataType::BFLOAT16>(host_input, mesh_device, to_mesh_replicated);
    auto input_tensor = ttml::autograd::create_tensor(device_input);
    auto result = linear(input_tensor);

    // Read the output back as one xtensor per device.
    ttml::core::MeshToXTensorVariant<float> per_device_composer = ttml::core::VectorMeshToXTensor<float>(mesh_dims);
    auto device_outputs = ttml::core::to_xtensor<float>(result->get_value(), per_device_composer);

    ttml::core::MeshToXTensorVariant<float> concat_over_dim2 = ttml::core::ConcatMeshToXTensor<float>(mesh_dims, 2U);
    ttml::core::MeshToXTensorVariant<float> concat_over_dim3 = ttml::core::ConcatMeshToXTensor<float>(mesh_dims, 3U);
    // Weight read back as (1, 1, out_features, in_features).
    auto host_weight = ttml::core::to_xtensor<float>(weight_param->get_value(), concat_over_dim2);
    auto host_bias = ttml::core::to_xtensor<float>(bias_param->get_value(), concat_over_dim3);

    // Reference result, forced to (1, 1, 1, out_features) so it can be sliced along the last axis.
    auto reference = xt::linalg::dot(host_input, xt::transpose(host_weight[0], {0, 1, 3, 2}));
    reference = reference.reshape({1U, 1U, 1U, out_features});
    if (has_bias) {
        reference += host_bias[0];
    }

    // Output 0 is checked against the first half of the reference's last axis.
    EXPECT_TRUE(xt::allclose(
        xt::view(reference, xt::all(), xt::all(), xt::all(), xt::range(0, out_features / 2)),
        device_outputs[0],
        /* rtol */ 1e-2,
        /* atol */ 1e-2));
    // Output 1 is checked against the second half.
    EXPECT_TRUE(xt::allclose(
        xt::view(reference, xt::all(), xt::all(), xt::all(), xt::range(out_features / 2, out_features)),
        device_outputs[1],
        /* rtol */ 1e-2,
        /* atol */ 1e-2));
};

// ColumnParallelLinear without bias, all-gather disabled.
//
// A random input is replicated to the mesh and run through the layer. The host-side reference
// input x weight^T is reshaped to (1, 1, 1, out_features); the per-device output at index 0 is
// compared with the first half of its last dimension and the output at index 1 with the second
// half.
TEST_F(N300TensorParallelLinearTest, ColumnParallelLinearNoBiasNoAllGather) {
    uint32_t in_features = 64U;
    uint32_t out_features = 64U;
    bool has_bias = false;
    bool use_all_gather = false;

    auto layer = ttml::modules::distributed::ColumnParallelLinear(in_features, out_features, has_bias, use_all_gather);
    auto parameters = layer.parameters();
    // Only the weight is expected when bias is disabled.
    EXPECT_EQ(parameters.size(), 1UL + static_cast<size_t>(has_bias));

    auto weight = get_parameter(parameters, "weight");

    auto* device = &ttml::autograd::ctx().get_device();
    auto mesh_shape = device->shape();

    // Host-side input of shape (1, 1, 1, in_features), replicated to the mesh.
    xt::xarray<float> test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features});
    ttml::core::XTensorToMeshVariant<float> replicate_composer = ttml::core::ReplicateXTensorToMesh<float>(mesh_shape);
    auto tt_tensor = ttml::core::from_xtensor<float, DataType::BFLOAT16>(test_data, device, replicate_composer);
    auto tensor = ttml::autograd::create_tensor(tt_tensor);
    auto output = layer(tensor);

    // Read the output back as one xtensor per device.
    ttml::core::MeshToXTensorVariant<float> identity_composer = ttml::core::VectorMeshToXTensor<float>(mesh_shape);
    auto output_xtensor = ttml::core::to_xtensor<float>(output->get_value(), identity_composer);

    // No bias in this test, so only the dim-2 composer for the weight is needed.
    ttml::core::MeshToXTensorVariant<float> concat_composer_2 = ttml::core::ConcatMeshToXTensor<float>(mesh_shape, 2U);
    // (1, 1, out_features, in_features)
    auto weight_xtensor = ttml::core::to_xtensor<float>(weight->get_value(), concat_composer_2);

    // Reference result, forced to (1, 1, 1, out_features) so it can be sliced along the last axis.
    auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2}));
    expected_output = expected_output.reshape({1U, 1U, 1U, out_features});

    // Output 0 is checked against the first half of the reference's last axis.
    EXPECT_TRUE(xt::allclose(
        xt::view(expected_output, xt::all(), xt::all(), xt::all(), xt::range(0, out_features / 2)),
        output_xtensor[0],
        /* rtol */ 1e-2,
        /* atol */ 1e-2));
    // Output 1 is checked against the second half.
    EXPECT_TRUE(xt::allclose(
        xt::view(expected_output, xt::all(), xt::all(), xt::all(), xt::range(out_features / 2, out_features)),
        output_xtensor[1],
        /* rtol */ 1e-2,
        /* atol */ 1e-2));
};
Loading