Change app/lm to use fl::Tensor
Summary: See title.

Reviewed By: benoitsteiner

Differential Revision: D33521672

fbshipit-source-id: 627cb48045829b3289ef4a1de9e29cc48e8525f3
jacobkahn authored and facebook-github-bot committed May 3, 2022
1 parent a305da8 commit d82bd3b
Showing 5 changed files with 59 additions and 51 deletions.
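All of the substitutions below follow one pattern: af::array values and ArrayFire free functions become the backend-agnostic fl::Tensor and its fl:: counterparts. The following standalone sketch (not part of the commit; names, shapes, and values are invented, and the umbrella include is assumed to cover the tensor API) pairs the retired calls with the replacements used in this diff:

#include "flashlight/fl/flashlight.h" // umbrella header already used by app/lm

using fl::Tensor;

void tensorMigrationSketch() {
  // af::randu(dims)               ->  fl::rand(shape)
  Tensor x = fl::rand({8, 4});
  // af::sum(a != v, 0)            ->  fl::sum(a != v, {0}, /* keepDims = */ true)
  Tensor colCounts = fl::sum(x != 0.f, {0}, /* keepDims = */ true);
  // af::count<float>(mask)        ->  fl::countNonzero(mask).asScalar<float>()
  float numNonzero = fl::countNonzero(x != 0.f).asScalar<float>();
  // af::mean<float>(a)            ->  fl::mean(a).scalar<float>()
  float mean = fl::mean(x).scalar<float>();
  // af::flat(a) / af::moddims(a)  ->  a.flatten() / fl::reshape(a, shape)
  Tensor flat = x.flatten();
  Tensor restored = fl::reshape(flat, x.shape());
  // af::seq(0, n - 2), af::span   ->  fl::range(0, n - 1), fl::span (end-exclusive)
  Tensor allButLastRow = x(fl::range(0, x.dim(0) - 1), fl::span);
  // a.as(s32)                     ->  a.astype(fl::dtype::s32)
  Tensor asInt = colCounts.astype(fl::dtype::s32);
}

Note that fl::range is end-exclusive, which is why af::seq(0, T - 2) becomes fl::range(0, T - 1) in Trainer.cpp below.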
1 change: 1 addition & 0 deletions flashlight/app/lm/Test.cpp
@@ -6,6 +6,7 @@
*/

#include "flashlight/app/lm/Trainer.h"
#include "flashlight/fl/tensor/Init.h"

using namespace fl::pkg::runtime;
using namespace fl::lib;
1 change: 1 addition & 0 deletions flashlight/app/lm/Train.cpp
@@ -6,6 +6,7 @@
*/

#include "flashlight/app/lm/Trainer.h"
#include "flashlight/fl/tensor/Init.h"

using namespace fl::pkg::runtime;
using namespace fl::lib;
80 changes: 43 additions & 37 deletions flashlight/app/lm/Trainer.cpp
@@ -6,8 +6,13 @@
*/

#include "flashlight/app/lm/Trainer.h"

#include <algorithm>

#include "flashlight/fl/tensor/Compute.h"
#include "flashlight/fl/tensor/Index.h"
#include "flashlight/fl/tensor/TensorBase.h"

using namespace fl::pkg::runtime;
using namespace fl::pkg::runtime;
using namespace fl::pkg::text;
@@ -241,7 +246,7 @@ Trainer::Trainer(const std::string& mode) {
gflagsStr_ = fl::pkg::runtime::serializeGflags();
FL_LOG_MASTER(INFO) << "Gflags after parsing \n" << serializeGflags("; ");

- initArrayFire();
+ this->init();
if (FLAGS_distributed_enable) {
reducer_ = std::make_shared<fl::CoalescingReducer>(1.0, true, true);
}
@@ -332,7 +337,8 @@ void Trainer::trainStep() {
fl::Variable input, target;
sampleTimerMeter_.resume();
std::tie(input, target) = getInputAndTarget(trainDataset_->get(batchIdx_));
- af::array inputSizes = af::sum(input.array() != kPadIdx_, 0);
+ Tensor inputSizes =
+     fl::sum(input.tensor() != kPadIdx_, {0}, /* keepDims = */ true);
sampleTimerMeter_.stopAndIncUnit();

while (true) {
@@ -349,8 +355,9 @@
// 3. Backward
bwdTimeMeter_.resume();
optimizer_->zeroGrad();
- float numTokens = af::count<float>(target.array() != kPadIdx_);
- af::array numTokensArr = af::array(1, &numTokens);
+ float numTokens =
+     fl::countNonzero(target.tensor() != kPadIdx_).asScalar<float>();
+ Tensor numTokensArr = fl::fromScalar(numTokens);
if (FLAGS_distributed_enable) {
fl::allReduce(numTokensArr);
}
@@ -373,7 +380,8 @@
if (numTokens > 0) {
auto weight =
numTokens / (FLAGS_data_tokens_per_sample * FLAGS_data_batch_size);
- trainLossMeter_.add(af::mean<float>(loss.array()) / numTokens, weight);
+ trainLossMeter_.add(
+     fl::mean(loss.tensor()).scalar<float>() / numTokens, weight);
tokenCountMeter_.add(numTokens);
}
break;
@@ -394,15 +402,17 @@ void Trainer::evalStep() {
for (const auto& sample : *validDataset_) {
fl::Variable input, target;
std::tie(input, target) = getInputAndTarget(sample);
- af::array inputSizes = af::sum(input.array() != kPadIdx_, 0);
+ Tensor inputSizes = fl::sum(input.tensor() != kPadIdx_, {0});
auto output = network_->forward({input, fl::noGrad(inputSizes)}).front();
auto loss = criterion_->forward({output, target}).front();
- auto numTokens = af::count<int>(target.array() != kPadIdx_);
+ auto numTokens =
+     fl::countNonzero(target.tensor() != kPadIdx_).scalar<unsigned>();
if (numTokens > 0) {
auto weight = numTokens /
static_cast<double>(
FLAGS_data_tokens_per_sample * FLAGS_data_batch_size);
- validLossMeter_.add(af::mean<double>(loss.array()) / numTokens, weight);
+ validLossMeter_.add(
+     fl::mean(loss.tensor()).asScalar<double>() / numTokens, weight);
}
}
}
@@ -629,62 +639,62 @@ void Trainer::createOptimizer() {

/* ============= Stateful training helpers ============= */
std::pair<fl::Variable, fl::Variable> Trainer::getInputAndTarget(
- const std::vector<af::array>& sample) const {
+ const std::vector<Tensor>& sample) const {
// sample.size() == 1
// sample[0] has size T x B
fl::Variable input, target;
- auto T = sample[0].dims(0);
+ auto T = sample[0].dim(0);

if (FLAGS_train_task == "mask") {
// TODO: need cleaning + correctness checking

// do masking of input and target
- af::array randMatrix = af::randu(sample[0].dims());
- af::array randMatrixSorted, randMatrixSortedIndices;
+ Tensor randMatrix = fl::rand(sample[0].shape());
+ Tensor randMatrixSorted, randMatrixSortedIndices;
// create random permutation
- af::sort(randMatrixSorted, randMatrixSortedIndices, randMatrix, 0);
- randMatrixSortedIndices = af::flat(randMatrixSortedIndices);
+ fl::sort(randMatrixSorted, randMatrixSortedIndices, randMatrix, 0);
+ randMatrixSortedIndices = randMatrixSortedIndices.flatten();

- af::array inputMasked = af::flat(sample[0]);
+ Tensor inputMasked = sample[0].flatten();
// set min length of the masked tokens
int nTotalMask =
std::max(int(FLAGS_mask_prob * T), (int)FLAGS_mask_min_length);
// set total mask
- af::array totalMask = randMatrixSortedIndices < nTotalMask;
- af::array notMasked = !totalMask;
- af::array woMaskTokenMask = randMatrixSortedIndices <
+ Tensor totalMask = randMatrixSortedIndices < nTotalMask;
+ Tensor notMasked = !totalMask;
+ Tensor woMaskTokenMask = randMatrixSortedIndices <
(FLAGS_mask_rand_token_prob + FLAGS_mask_same_token_prob) * nTotalMask;
- af::array randMask =
+ Tensor randMask =
randMatrixSortedIndices < FLAGS_mask_rand_token_prob * nTotalMask;

inputMasked(totalMask) = kMaskIdx_;
- inputMasked(woMaskTokenMask) = af::flat(sample[0])(woMaskTokenMask);
- if (af::anyTrue<bool>(randMask)) {
+ inputMasked(woMaskTokenMask) = sample[0].flatten()(woMaskTokenMask);
+ if (fl::any(randMask).asScalar<bool>()) {
// exclude 4 special tokens from the consideration: pad, eos, unk and
// mask
std::vector<int> specialTokens = {
kPadIdx_, kEosIdx_, kUnkIdx_, kMaskIdx_};
std::sort(specialTokens.begin(), specialTokens.end());
- auto randVals = (af::randu(af::sum(randMask).scalar<unsigned int>()) *
+ auto randVals = (fl::rand({fl::sum(randMask).asScalar<unsigned int>()}) *
(dictionary_.entrySize() - 1 - specialTokens.size()))
- .as(s32);
+ .astype(fl::dtype::s32);
for (auto specialVal : specialTokens) {
auto specialMask = randVals >= specialVal;
randVals(specialMask) = randVals(specialMask) + 1;
}
inputMasked(randMask) = randVals;
}
// fix position where it was pad index to be pad
- inputMasked(af::flat(sample[0] == kPadIdx_)) = kPadIdx_;
- inputMasked = af::moddims(inputMasked, sample[0].dims());
+ inputMasked((sample[0] == kPadIdx_).flatten()) = kPadIdx_;
+ inputMasked = fl::reshape(inputMasked, sample[0].shape());
input = fl::Variable(inputMasked, false);
- auto targetMasked = af::flat(sample[0]);
+ auto targetMasked = sample[0].flatten();
targetMasked(notMasked) = kPadIdx_;
- targetMasked = af::moddims(targetMasked, sample[0].dims());
+ targetMasked = fl::reshape(targetMasked, sample[0].shape());
target = fl::Variable(targetMasked, false);
} else if (FLAGS_train_task == "autoreg") {
- input = fl::Variable(sample[0](af::seq(0, T - 2), af::span), false);
- target = fl::Variable(sample[0](af::seq(1, T - 1), af::span), false);
+ input = fl::Variable(sample[0](fl::range(0, T - 1), fl::span), false);
+ target = fl::Variable(sample[0](fl::range(1, T), fl::span), false);
} else {
throw std::invalid_argument(
"Not supported train_task: " + FLAGS_train_task);
@@ -726,16 +736,16 @@ void Trainer::reduceGrads() {
if (!p.isGradAvailable()) {
p.addGrad(fl::constant(0.0, p.dims(), p.type(), false));
}
- auto& grad = p.grad().array();
- p.grad().array() = grad;
+ auto& grad = p.grad().tensor();
+ p.grad().tensor() = grad;
reducer_->add(p.grad());
}
reducer_->finalize();
}
}

/* ============= Stateless training helpers ============= */
- void Trainer::initArrayFire() const {
+ void Trainer::init() const {
// Set arrayfire seed for reproducibility
fl::setSeed(FLAGS_train_seed);
}
@@ -847,11 +857,7 @@ void Trainer::saveCheckpoint(const std::string& path, const std::string& suffix)

void Trainer::logMemoryManagerStatus() const {
if (isMaster()) {
- auto* curMemMgr =
-     fl::MemoryManagerInstaller::currentlyInstalledMemoryManager();
- if (curMemMgr) {
-   curMemMgr->printInfo("Memory Manager Stats", 0 /* device id */);
- }
+ fl::detail::getMemMgrInfo("Memory Manager Stats", /* device id = */ 0);
}
}

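One detail of the trainStep() change above: the non-pad token count is now produced by fl::countNonzero and wrapped with fl::fromScalar so that it can be summed across workers with fl::allReduce. A minimal sketch of that pattern follows; the helper name and signature are invented for illustration, and the umbrella include is assumed to pull in the distributed API as it does for the Trainer:

#include "flashlight/fl/flashlight.h" // umbrella header, as included via Trainer.h

// Hypothetical helper mirroring trainStep(): count non-pad tokens locally,
// then sum the count over all workers when distributed training is enabled.
float globalNumTokens(const fl::Tensor& target, int padIdx, bool distributedEnabled) {
  float numTokens = fl::countNonzero(target != padIdx).asScalar<float>();
  if (distributedEnabled) {
    fl::Tensor numTokensArr = fl::fromScalar(numTokens);
    fl::allReduce(numTokensArr); // in-place sum across processes
    numTokens = numTokensArr.asScalar<float>();
  }
  return numTokens;
}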
14 changes: 7 additions & 7 deletions flashlight/app/lm/Trainer.h
@@ -12,14 +12,10 @@
#include <gflags/gflags.h>
#include <glog/logging.h>

#include "flashlight/pkg/runtime/Runtime.h"
#include "flashlight/app/lm/common/Defines.h"
#include "flashlight/pkg/runtime/Runtime.h"
#include "flashlight/pkg/text/data/TextDataset.h"

#include "flashlight/pkg/runtime/amp/DynamicScaler.h"
#include "flashlight/pkg/runtime/common/DistributedUtils.h"
#include "flashlight/pkg/runtime/common/Serializer.h"
#include "flashlight/pkg/runtime/plugin/ModulePlugin.h"
#include "flashlight/fl/contrib/contrib.h"
#include "flashlight/fl/flashlight.h"
#include "flashlight/lib/common/String.h"
@@ -28,6 +24,10 @@
#include "flashlight/lib/text/dictionary/Utils.h"
#include "flashlight/lib/text/tokenizer/PartialFileReader.h"
#include "flashlight/lib/text/tokenizer/Tokenizer.h"
#include "flashlight/pkg/runtime/amp/DynamicScaler.h"
#include "flashlight/pkg/runtime/common/DistributedUtils.h"
#include "flashlight/pkg/runtime/common/Serializer.h"
#include "flashlight/pkg/runtime/plugin/ModulePlugin.h"

namespace fl {
namespace app {
@@ -158,12 +158,12 @@ class Trainer {

/* Stateful training helpers */
std::pair<fl::Variable, fl::Variable> getInputAndTarget(
- const std::vector<af::array>& sample) const;
+ const std::vector<Tensor>& sample) const;
void setLr();
void reduceGrads();

/* Stateless training helpers */
- void initArrayFire() const;
+ void init() const;
std::vector<int> parseCutoffs(int64_t nClasses) const;
bool isMaster() const;
void checkArgs() const;
@@ -40,13 +40,14 @@ class LmAdae512SinposL8H8Fc1024Dp03Ldp0Adsm : public fl::Container {
std::vector<fl::Variable> forward(
const std::vector<fl::Variable>& input) override {
auto out = input[0];
- auto xSizes = input[1].array();
+ auto xSizes = input[1].tensor();
// expected input dims T x B x 1 x 1
int T = out.dims(0), B = out.dims(1);
- auto inputMaxSize = af::tile(af::max(xSizes), 1, B);
- af::array inputNotPaddedSize = af::ceil(xSizes * T / inputMaxSize);
- auto padMask = af::iota(af::dim4(T, 1), af::dim4(1, B)) <
-     af::tile(inputNotPaddedSize, T, 1);
+ // TODO{fl::Tensor} - first non-singleton dimension check
+ auto inputMaxSize = fl::tile(fl::amax(xSizes, {0}), {1, B});
+ Tensor inputNotPaddedSize = fl::ceil(xSizes * T / inputMaxSize);
+ auto padMask =
+     fl::iota({T, 1}, {1, B}) < fl::tile(inputNotPaddedSize, {T, 1});
out = frontend_->forward(out);
for (int trIdx = 0; trIdx < transformers_.size(); trIdx++) {
out = transformers_[trIdx]->forward({out, fl::noGrad(padMask)}).front();
@@ -75,8 +76,7 @@ class LmAdae512SinposL8H8Fc1024Dp03Ldp0Adsm : public fl::Container {
} // namespace rasrLM

extern "C" fl::Module* createModule(int64_t, int64_t nLabel) {
- auto m =
-     std::make_unique<LmAdae512SinposL8H8Fc1024Dp03Ldp0Adsm>(nLabel);
+ auto m = std::make_unique<LmAdae512SinposL8H8Fc1024Dp03Ldp0Adsm>(nLabel);
return m.release();
}
