Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tt-train] Enable tensor parallel for MNIST #17506

Merged
merged 15 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
MNIST works with TP?
  • Loading branch information
rfurko-tt committed Feb 2, 2025
commit d3a61b7fd5a788240ce1b472bea852ef74270861
23 changes: 20 additions & 3 deletions tt-train/sources/examples/mnist_mlp/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,14 @@ float evaluate(DataLoader &test_dataloader, Model &model, size_t num_targets) {
model->eval();
float num_correct = 0;
float num_samples = 0;
auto *device = &ttml::autograd::ctx().get_device();
for (const auto &[data, target] : test_dataloader) {
auto output = (*model)(data);
auto output_vec = ttml::core::to_vector(output->get_value());
auto target_vec = ttml::core::to_vector(target->get_value());
ttml::core::MeshToXTensorVariant<float> composer = ttml::core::VectorMeshToXTensor<float>(device->shape());
auto output_xtensor = ttml::core::to_xtensor(output->get_value(), composer)[0];
auto target_xtensor = ttml::core::to_xtensor(target->get_value(), composer)[0];
auto output_vec = std::vector<float>(output_xtensor.begin(), output_xtensor.end());
auto target_vec = std::vector<float>(target_xtensor.begin(), target_xtensor.end());
for (size_t i = 0; i < output_vec.size(); i += num_targets) {
auto predicted_class = std::distance(
output_vec.begin() + i,
Expand Down Expand Up @@ -194,12 +198,25 @@ int main(int argc, char **argv) {

LossAverageMeter loss_meter;
int training_step = 0;

// Converts a loss tensor into a single host-side float.
// The tensor is read back through a VectorMeshToXTensor composer built from
// the device shape, which yields a sequence of xtensors (presumably one per
// device in the mesh -- confirm against ttml::core::to_xtensor). Element 0 of
// each is summed and the sum is divided by the number of xtensors returned.
auto get_loss_value = [device](const TensorPtr &loss) {
    ttml::core::MeshToXTensorVariant<float> composer = ttml::core::VectorMeshToXTensor<float>(device->shape());
    const auto per_shard_losses = ttml::core::to_xtensor(loss->get_value(), composer);

    // Accumulate the first element of every returned xtensor, in order.
    float total = 0.0F;
    for (const auto &shard_loss : per_shard_losses) {
        total += shard_loss(0);
    }

    // Mean over the returned xtensors.
    return total / static_cast<float>(per_shard_losses.size());
};

for (size_t epoch = 0; epoch < config.num_epochs; ++epoch) {
for (const auto &[data, target] : train_dataloader) {
optimizer.zero_grad();
auto output = (*model)(data);
auto loss = ttml::ops::cross_entropy_loss(output, target);
auto loss_float = ttml::core::to_vector(loss->get_value())[0];
auto loss_float = get_loss_value(loss);
loss_meter.update(loss_float, config.batch_size);
if (training_step % config.logging_interval == 0) {
fmt::print("Step: {:5d} | Average Loss: {:.4f}\n", training_step, loss_meter.average());
Expand Down
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/modules/distributed/linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ autograd::TensorPtr RowParallelLinear::operator()(autograd::TensorPtr tensor) {

// do not pass bias
tensor = ops::linear_op(tensor, m_weight, /* bias */ nullptr);
tensor = ops::distributed::all_reduce(tensor, tensor->rank() - 1U);
tensor = ops::distributed::all_reduce(tensor);
if (m_bias != nullptr) {
tensor = ops::add(tensor, m_bias);
}
Expand Down
6 changes: 3 additions & 3 deletions tt-train/sources/ttml/ops/distributed/comm_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ autograd::TensorPtr all_gather(const autograd::TensorPtr& tensor, int dim) {
return out;
}

autograd::TensorPtr all_reduce(const autograd::TensorPtr& tensor, int dim) {
auto out = autograd::create_tensor(
ttnn::experimental::all_reduce(tensor->get_value(), ttnn::operations::reduction::ReduceType::Sum, dim));
autograd::TensorPtr all_reduce(const autograd::TensorPtr& tensor) {
auto out = autograd::create_tensor(ttnn::experimental::all_reduce(
tensor->get_value(), ttnn::operations::reduction::ReduceType::Sum, 1, std::nullopt, ttnn::ccl::Topology::Ring));
autograd::GradFunction grad = [tensor, out]() { tensor->set_grad(out->get_grad()); };
auto links = autograd::get_links(tensor);
out->set_node(autograd::ctx().add_backward_node(std::move(grad), links));
Expand Down
2 changes: 1 addition & 1 deletion tt-train/sources/ttml/ops/distributed/comm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace ttml::ops::distributed {

autograd::TensorPtr scatter(const autograd::TensorPtr& tensor, int dim);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this file.

autograd::TensorPtr all_reduce(const autograd::TensorPtr& tensor, int dim);
autograd::TensorPtr all_reduce(const autograd::TensorPtr& tensor);
autograd::TensorPtr all_gather(const autograd::TensorPtr& tensor, int dim);

} // namespace ttml::ops::distributed