[tt-train] Enable tensor parallel for MNIST #17506

Merged
merged 15 commits into from Feb 5, 2025
Changes from 1 commit
Move std::visit to separate functions. Move MNIST model to model file
rfurko-tt committed Feb 5, 2025
commit d7e255b85e660ea28dcacf15b9d0b4ed11e1abf0
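This commit replaces the inline `std::visit` lambdas with named helpers (`model_to_eval`, `model_to_train`, `run_model`, `get_model_parameters`, `load_model`, `save_model`) and moves the tensor-parallel model into `model.hpp`/`model.cpp`. Below is a minimal, self-contained sketch of that dispatch pattern; `ModelA`/`ModelB` are hypothetical stand-ins for `ttml::modules::MultiLayerPerceptron` and `MNISTTensorParallel`, not the real types:

```cpp
#include <cstdio>
#include <memory>
#include <variant>

// Stand-ins for the two concrete model types wrapped by the variant.
struct ModelA { void eval() { std::puts("A: eval"); } void train() { std::puts("A: train"); } };
struct ModelB { void eval() { std::puts("B: eval"); } void train() { std::puts("B: train"); } };

// One alias for the variant, so call sites never spell out the alternatives.
using Model = std::variant<std::shared_ptr<ModelA>, std::shared_ptr<ModelB>>;

// Named helpers hide the std::visit boilerplate; the generic lambda works
// for every alternative because it only relies on the shared interface.
void model_to_eval(Model &model) {
    std::visit([](auto &m) { m->eval(); }, model);
}

void model_to_train(Model &model) {
    std::visit([](auto &m) { m->train(); }, model);
}

int main() {
    Model model = std::make_shared<ModelB>();
    model_to_eval(model);   // prints "B: eval"
    model_to_train(model);  // prints "B: train"
    return 0;
}
```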
1 change: 1 addition & 0 deletions tt-train/sources/examples/mnist_mlp/CMakeLists.txt
@@ -3,6 +3,7 @@ project(mnist_mlp)
set(SOURCES
main.cpp
utils.cpp
model.cpp
)
CPMAddPackage(NAME mnist_dataset GITHUB_REPOSITORY wichtounet/mnist GIT_TAG master)
include_directories(${mnist_dataset_SOURCE_DIR}/include)
165 changes: 89 additions & 76 deletions tt-train/sources/examples/mnist_mlp/main.cpp
@@ -2,6 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <yaml-cpp/node/node.h>

#include <CLI/CLI.hpp>
#include <core/ttnn_all_includes.hpp>
#include <functional>
@@ -14,12 +16,12 @@
#include "core/tt_tensor_utils.hpp"
#include "datasets/dataloader.hpp"
#include "datasets/in_memory_dataset.hpp"
#include "model.hpp"
#include "models/mlp.hpp"
#include "modules/distributed/linear.hpp"
#include "ops/losses.hpp"
#include "optimizers/sgd.hpp"
#include "serialization/serializable.hpp"
#include "utils.hpp"
#include "yaml-cpp/node/node.h"

using ttml::autograd::TensorPtr;

@@ -30,60 +32,10 @@ using DataLoader = ttml::datasets::DataLoader<
std::function<BatchType(std::vector<DatasetSample> &&samples)>,
BatchType>;

constexpr auto model_name = "mlp";
constexpr auto optimizer_name = "optimizer";

class MnistTP : public ttml::autograd::ModuleBase {
public:
MnistTP() {
m_linear1 = std::make_shared<ttml::modules::distributed::ColumnParallelLinear>(
784, 128, /* has_bias */ true, /* gather_output */ false);
m_linear2 = std::make_shared<ttml::modules::distributed::RowParallelLinear>(
128, 10, /* has_bias */ true, /* input_is_parallel */ true);
create_name(model_name);
register_module(m_linear1, "linear1");
register_module(m_linear2, "linear2");
}
using Model = std::variant<std::shared_ptr<ttml::modules::MultiLayerPerceptron>, std::shared_ptr<MNISTTensorParallel>>;

ttml::autograd::TensorPtr operator()(ttml::autograd::TensorPtr tensor) {
tensor = (*m_linear1)(tensor);
tensor = ttml::ops::relu(tensor);
tensor = (*m_linear2)(tensor);
return tensor;
}

private:
std::shared_ptr<ttml::modules::distributed::ColumnParallelLinear> m_linear1;
std::shared_ptr<ttml::modules::distributed::RowParallelLinear> m_linear2;
};

template <typename Model>
float evaluate(DataLoader &test_dataloader, Model &model, size_t num_targets) {
std::visit([](auto &model) { model->eval(); }, model);
float num_correct = 0;
float num_samples = 0;
auto *device = &ttml::autograd::ctx().get_device();
for (const auto &[data, target] : test_dataloader) {
auto output = std::visit([&data](auto &model) { return (*model)(data); }, model);
ttml::core::MeshToXTensorVariant<float> composer = ttml::core::VectorMeshToXTensor<float>(device->shape());
auto output_xtensor = ttml::core::to_xtensor(output->get_value(), composer)[0];
auto target_xtensor = ttml::core::to_xtensor(target->get_value(), composer)[0];
auto output_vec = std::vector<float>(output_xtensor.begin(), output_xtensor.end());
auto target_vec = std::vector<float>(target_xtensor.begin(), target_xtensor.end());
for (size_t i = 0; i < output_vec.size(); i += num_targets) {
auto predicted_class = std::distance(
output_vec.begin() + i,
std::max_element(output_vec.begin() + i, output_vec.begin() + (i + num_targets)));
auto target_class = std::distance(
target_vec.begin() + i,
std::max_element(target_vec.begin() + i, target_vec.begin() + (i + num_targets)));
num_correct += static_cast<float>(predicted_class == target_class);
num_samples++;
}
}
std::visit([](auto &model) { model->train(); }, model);
return num_correct / num_samples;
};
const std::string model_name = "mlp";
const std::string optimizer_name = "optimizer";

struct TrainingConfig {
uint32_t batch_size = 128;
@@ -119,6 +71,76 @@ void initialize_device(bool enable_tp) {
}
}

void model_to_eval(Model &model) {
std::visit([](auto &model) { model->eval(); }, model);
}

void model_to_train(Model &model) {
std::visit([](auto &model) { model->train(); }, model);
}

ttml::autograd::TensorPtr run_model(Model &model, const ttml::autograd::TensorPtr &data) {
return std::visit([&data](auto &model) { return (*model)(data); }, model);
}

ttml::serialization::NamedParameters get_model_parameters(Model &model) {
return std::visit([](auto &model) { return model->parameters(); }, model);
}

void load_model(
Model &model,
const TrainingConfig &config,
ttml::optimizers::SGD &optimizer,
const std::string &model_name,
const std::string &optimizer_name) {
std::visit(
[&config, &optimizer, &model_name, &optimizer_name](auto &model) {
load_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
},
model);
}

void save_model(
Model &model,
const TrainingConfig &config,
ttml::optimizers::SGD &optimizer,
const std::string &model_name,
const std::string &optimizer_name) {
std::visit(
[&config, &optimizer, &model_name, &optimizer_name](auto &model) {
save_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
},
model);
}

template <typename Model>
float evaluate(DataLoader &test_dataloader, Model &model, size_t num_targets) {
model_to_eval(model);
float num_correct = 0;
float num_samples = 0;
auto *device = &ttml::autograd::ctx().get_device();
for (const auto &[data, target] : test_dataloader) {
auto output = run_model(model, data);
ttml::core::MeshToXTensorVariant<float> composer = ttml::core::VectorMeshToXTensor<float>(device->shape());
auto output_xtensor = ttml::core::to_xtensor(output->get_value(), composer)[0];
auto target_xtensor = ttml::core::to_xtensor(target->get_value(), composer)[0];
auto output_vec = std::vector<float>(output_xtensor.begin(), output_xtensor.end());
auto target_vec = std::vector<float>(target_xtensor.begin(), target_xtensor.end());
for (size_t i = 0; i < output_vec.size(); i += num_targets) {
auto predicted_class = std::distance(
output_vec.begin() + i,
std::max_element(output_vec.begin() + i, output_vec.begin() + (i + num_targets)));
auto target_class = std::distance(
target_vec.begin() + i,
std::max_element(target_vec.begin() + i, target_vec.begin() + (i + num_targets)));
num_correct += static_cast<float>(predicted_class == target_class);
num_samples++;
}
}
model_to_train(model);
return num_correct / num_samples;
};

int main(int argc, char **argv) {
CLI::App app{"Mnist Example"};
argv = app.ensure_utf8(argv);
@@ -174,9 +196,9 @@ int main(int argc, char **argv) {
auto train_dataloader = DataLoader(training_dataset, config.batch_size, /* shuffle */ true, collate_fn);
auto test_dataloader = DataLoader(test_dataset, config.batch_size, /* shuffle */ false, collate_fn);

std::variant<std::shared_ptr<ttml::modules::MultiLayerPerceptron>, std::shared_ptr<MnistTP>> model;
Model model;
if (enable_tp) {
model = std::make_shared<MnistTP>();
model = std::make_shared<MNISTTensorParallel>();
} else {
model = ttml::models::mlp::create(config.mlp_config);
}
@@ -193,15 +215,15 @@ int main(int argc, char **argv) {
fmt::print(" Dampening {}\n", sgd_config.dampening);
fmt::print(" Weight decay: {}\n", sgd_config.weight_decay);
fmt::print(" Nesterov: {}\n", sgd_config.nesterov);
auto parameters = std::visit([](auto &model) { return model->parameters(); }, model);
auto parameters = get_model_parameters(model);
auto optimizer = ttml::optimizers::SGD(parameters, sgd_config);
if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) {
fmt::print("Loading model from {}\n", config.model_path);
std::visit(
[&config, &optimizer, model_name = model_name, optimizer_name = optimizer_name](auto &model) {
load_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
},
model);
if (enable_tp) {
fmt::println("Loading model for tensor parallelism is not supported yet. Loading model has been skipped.");
} else {
fmt::print("Loading model from {}\n", config.model_path);
load_model(model, config, optimizer, model_name, optimizer_name);
}
}

// evaluate model before training (sanity check to get reasonable accuracy
@@ -230,7 +252,7 @@ int main(int argc, char **argv) {
for (size_t epoch = 0; epoch < config.num_epochs; ++epoch) {
for (const auto &[data, target] : train_dataloader) {
optimizer.zero_grad();
auto output = std::visit([&data](auto &model) { return (*model)(data); }, model);
auto output = run_model(model, data);
auto loss = ttml::ops::cross_entropy_loss(output, target);
auto loss_float = get_loss_value(loss);
loss_meter.update(loss_float, config.batch_size);
@@ -239,11 +261,7 @@ int main(int argc, char **argv) {
}
if (!config.model_path.empty() && training_step % config.model_save_interval == 0) {
fmt::print("Saving model to {}\n", config.model_path);
std::visit(
[&config, &optimizer, model_name = model_name, optimizer_name = optimizer_name](auto &model) {
save_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
},
model);
save_model(model, config, optimizer, model_name, optimizer_name);
}

loss->backward();
@@ -263,12 +281,7 @@ int main(int argc, char **argv) {

if (!config.model_path.empty()) {
fmt::print("Saving model to {}\n", config.model_path);
// save_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
std::visit(
[&config, &optimizer, model_name = model_name, optimizer_name = optimizer_name](auto &model) {
save_training_state(config.model_path, model, optimizer, model_name, optimizer_name);
},
model);
save_model(model, config, optimizer, model_name, optimizer_name);
}

return 0;
24 changes: 24 additions & 0 deletions tt-train/sources/examples/mnist_mlp/model.cpp
@@ -0,0 +1,24 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "model.hpp"

#include "ops/unary_ops.hpp"

MNISTTensorParallel::MNISTTensorParallel() {
m_linear1 = std::make_shared<ttml::modules::distributed::ColumnParallelLinear>(
784, 128, /* has_bias */ true, /* gather_output */ false);
m_linear2 = std::make_shared<ttml::modules::distributed::RowParallelLinear>(
128, 10, /* has_bias */ true, /* input_is_parallel */ true);
create_name("mlp");
register_module(m_linear1, "linear1");
register_module(m_linear2, "linear2");
}

ttml::autograd::TensorPtr MNISTTensorParallel::operator()(ttml::autograd::TensorPtr tensor) {
tensor = (*m_linear1)(tensor);
tensor = ttml::ops::relu(tensor);
tensor = (*m_linear2)(tensor);
return tensor;
}
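For context on why these two layers compose without an intermediate gather: `ColumnParallelLinear` with `gather_output = false` leaves its activations sharded along the hidden dimension, and `RowParallelLinear` with `input_is_parallel = true` consumes exactly that sharding, summing partial products across devices. A toy single-host sketch of the underlying math — not the ttml API; the two-"device" layout and the 2-in/4-hidden/1-out sizes are purely illustrative:

```cpp
#include <cstdio>
#include <vector>

// Toy tensor-parallel MLP simulated on one host with two "devices".
// Layer 1 is column-parallel: each shard owns half of the 4 hidden units.
// Layer 2 is row-parallel: each shard consumes its half of the hidden
// vector; summing the partial outputs plays the role of the all-reduce.
int main() {
    std::vector<float> x = {1.0f, 2.0f};            // input, 2 features
    // W1 is 4x2 (hidden x in); shard 0 owns rows 0-1, shard 1 rows 2-3.
    float W1[4][2] = {{1, 0}, {0, 1}, {1, 1}, {-1, 1}};
    // W2 is 1x4 (out x hidden); shard 0 owns cols 0-1, shard 1 cols 2-3.
    float W2[4] = {0.5f, -0.5f, 1.0f, 2.0f};

    float out = 0.0f;                               // all-reduce accumulator
    for (int shard = 0; shard < 2; ++shard) {
        float partial = 0.0f;
        for (int h = 2 * shard; h < 2 * shard + 2; ++h) {
            float hidden = W1[h][0] * x[0] + W1[h][1] * x[1];
            if (hidden < 0) hidden = 0;             // ReLU is elementwise, stays local
            partial += W2[h] * hidden;
        }
        out += partial;  // sum of shard partials equals the full matmul
    }
    std::printf("tensor-parallel output: %f\n", out);  // 4.5 for this data
    return 0;
}
```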
18 changes: 18 additions & 0 deletions tt-train/sources/examples/mnist_mlp/model.hpp
@@ -0,0 +1,18 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "autograd/module_base.hpp"
#include "modules/distributed/linear.hpp"

class MNISTTensorParallel : public ttml::autograd::ModuleBase {
public:
MNISTTensorParallel();
ttml::autograd::TensorPtr operator()(ttml::autograd::TensorPtr tensor);

private:
std::shared_ptr<ttml::modules::distributed::ColumnParallelLinear> m_linear1;
std::shared_ptr<ttml::modules::distributed::RowParallelLinear> m_linear2;
};
4 changes: 2 additions & 2 deletions tt-train/sources/examples/mnist_mlp/utils.hpp
@@ -39,7 +39,7 @@ class Timers {

template <typename Model, typename Optimizer>
void save_training_state(
std::string &model_path,
const std::string &model_path,
const std::shared_ptr<Model> &model,
Optimizer &optimizer,
const std::string &model_name,
@@ -52,7 +52,7 @@

template <typename Model, typename Optimizer>
void load_training_state(
std::string &model_path,
const std::string &model_path,
const std::shared_ptr<Model> &model,
Optimizer &optimizer,
const std::string &model_name,