[tt-train] Enable tensor parallel for MNIST #17506

Merged: 15 commits, Feb 5, 2025

@rfurko-tt commented Feb 3, 2025

Problem description

Enable tensor parallel training in tt-train. Start with MNIST.

What's changed

  • Add all_reduce, all_gather, and scatter operations
  • Add RowParallelLinear and ColumnParallelLinear layers
  • Add tensor parallel training support for MNIST
  • Model serialization is out of scope for this PR
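
For readers unfamiliar with the pattern, here is a minimal host-side sketch of how a column-parallel layer is usually paired with a row-parallel layer and an all_reduce. It is not the tt-train implementation: `Matrix`, `slice`, `all_reduce_sum`, and `tensor_parallel_forward` are hypothetical names, devices are simulated with a loop, and biases and the backward pass are omitted. The first weight is split by columns, so each device computes its slice of the hidden activation without communication. The second weight is split by rows, so each device produces a partial sum that all_reduce combines.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense matrix. Hypothetical stand-in for a device tensor, not a tt-train type.
struct Matrix {
    std::size_t rows;
    std::size_t cols;
    std::vector<float> data;
    Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0F) {}
    float& at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

Matrix matmul(const Matrix& a, const Matrix& b) {
    assert(a.cols == b.rows);
    Matrix out(a.rows, b.cols);
    for (std::size_t i = 0; i < a.rows; ++i)
        for (std::size_t k = 0; k < a.cols; ++k)
            for (std::size_t j = 0; j < b.cols; ++j)
                out.at(i, j) += a.at(i, k) * b.at(k, j);
    return out;
}

// Elementwise ReLU; it commutes with column sharding because it acts per element.
Matrix relu(Matrix m) {
    for (float& v : m.data) v = v > 0.0F ? v : 0.0F;
    return m;
}

// Copies the block [r0, r1) x [c0, c1) of m.
Matrix slice(const Matrix& m, std::size_t r0, std::size_t r1, std::size_t c0, std::size_t c1) {
    Matrix out(r1 - r0, c1 - c0);
    for (std::size_t r = r0; r < r1; ++r)
        for (std::size_t c = c0; c < c1; ++c)
            out.at(r - r0, c - c0) = m.at(r, c);
    return out;
}

// Simulated all_reduce(sum): on real hardware every device would receive this sum.
Matrix all_reduce_sum(const std::vector<Matrix>& partials) {
    assert(!partials.empty());
    Matrix sum = partials.front();
    for (std::size_t i = 1; i < partials.size(); ++i)
        for (std::size_t j = 0; j < sum.data.size(); ++j)
            sum.data[j] += partials[i].data[j];
    return sum;
}

// Computes relu(x * w1) * w2 with w1 sharded by columns and w2 sharded by rows.
Matrix tensor_parallel_forward(const Matrix& x, const Matrix& w1, const Matrix& w2,
                               std::size_t num_devices) {
    assert(num_devices > 0 && w1.cols % num_devices == 0 && w1.cols == w2.rows);
    const std::size_t shard = w1.cols / num_devices;
    std::vector<Matrix> partials;
    for (std::size_t d = 0; d < num_devices; ++d) {
        // Column-parallel step: each device owns a column block of w1.
        const Matrix w1_shard = slice(w1, 0, w1.rows, d * shard, (d + 1) * shard);
        const Matrix hidden = relu(matmul(x, w1_shard));
        // Row-parallel step: the matching row block of w2 yields a partial output.
        const Matrix w2_shard = slice(w2, d * shard, (d + 1) * shard, 0, w2.cols);
        partials.push_back(matmul(hidden, w2_shard));
    }
    return all_reduce_sum(partials);
}
```

Under these assumptions, `tensor_parallel_forward(x, w1, w2, n)` gives the same result as the dense `matmul(relu(matmul(x, w1)), w2)` whenever `n` divides the hidden width, and only one all_reduce is needed per layer pair in the forward pass.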

With tensor parallel: [image: mnist_tp]

Without tensor parallel: [image: mnist_without_tp]


namespace ttml::ops::distributed {

autograd::TensorPtr scatter(const autograd::TensorPtr& tensor, int dim);
I like this file.
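
As a rough illustration of what a `scatter(tensor, dim)` declaration like the one under review typically means, here is a host-side sketch. It assumes scatter splits a tensor into equal chunks along `dim`, one per device, and that all_gather is its inverse. That semantics is an assumption, not something the PR text states, and `Tensor2D`, `scatter_sketch`, and `all_gather_sketch` are hypothetical names unrelated to `autograd::TensorPtr`. Gradients are not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Row-major 2D tensor. Hypothetical stand-in, not a tt-train type.
struct Tensor2D {
    std::size_t rows;
    std::size_t cols;
    std::vector<float> data;
};

// Splits t into num_devices equal chunks along dim (0 = rows, 1 = columns).
// Shard d is the chunk that device d would keep.
std::vector<Tensor2D> scatter_sketch(const Tensor2D& t, int dim, std::size_t num_devices) {
    if (dim != 0 && dim != 1) throw std::invalid_argument("dim must be 0 or 1");
    const std::size_t extent = dim == 0 ? t.rows : t.cols;
    if (num_devices == 0 || extent % num_devices != 0)
        throw std::invalid_argument("extent must be divisible by num_devices");
    const std::size_t chunk = extent / num_devices;
    std::vector<Tensor2D> shards;
    for (std::size_t d = 0; d < num_devices; ++d) {
        Tensor2D s{dim == 0 ? chunk : t.rows, dim == 0 ? t.cols : chunk, {}};
        s.data.reserve(s.rows * s.cols);
        for (std::size_t r = 0; r < s.rows; ++r) {
            for (std::size_t c = 0; c < s.cols; ++c) {
                const std::size_t src_r = dim == 0 ? d * chunk + r : r;
                const std::size_t src_c = dim == 0 ? c : d * chunk + c;
                s.data.push_back(t.data[src_r * t.cols + src_c]);
            }
        }
        shards.push_back(std::move(s));
    }
    return shards;
}

// Concatenates shards along dim, undoing scatter_sketch.
Tensor2D all_gather_sketch(const std::vector<Tensor2D>& shards, int dim) {
    if (shards.empty() || (dim != 0 && dim != 1)) throw std::invalid_argument("bad arguments");
    std::size_t rows = dim == 0 ? 0 : shards.front().rows;
    std::size_t cols = dim == 0 ? shards.front().cols : 0;
    for (const Tensor2D& s : shards) (dim == 0 ? rows : cols) += dim == 0 ? s.rows : s.cols;
    Tensor2D out{rows, cols, std::vector<float>(rows * cols, 0.0F)};
    std::size_t offset = 0;  // position of the current shard along dim
    for (const Tensor2D& s : shards) {
        for (std::size_t r = 0; r < s.rows; ++r) {
            for (std::size_t c = 0; c < s.cols; ++c) {
                const std::size_t dst_r = dim == 0 ? offset + r : r;
                const std::size_t dst_c = dim == 0 ? c : offset + c;
                out.data[dst_r * cols + dst_c] = s.data[r * s.cols + c];
            }
        }
        offset += dim == 0 ? s.rows : s.cols;
    }
    return out;
}
```

In this sketch, scattering a 2x4 tensor along dim 1 across two devices gives each device a 2x2 column block, and gathering the shards along the same dim restores the original tensor.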

@rfurko-tt changed the title from "[WIP] tt-train tensor parallel" to "[WIP][tt-train] Enable tensor parallel for MNIST" on Feb 4, 2025
@rfurko-tt changed the title from "[WIP][tt-train] Enable tensor parallel for MNIST" to "[tt-train] Enable tensor parallel for MNIST" on Feb 4, 2025
@rfurko-tt merged commit 886377c into main on Feb 5, 2025
192 of 195 checks passed
@rfurko-tt deleted the rfurko/tensor_parallel branch on February 5, 2025 17:21
2 participants