Change pkg/runtime to use fl::Tensor (flashlight#786)
Summary:
Pull Request resolved: flashlight#786

Migrate `pkg/runtime` to use `fl::Tensor`.

Reviewed By: benoitsteiner

Differential Revision: D31984590

fbshipit-source-id: b8aa134735e759aed5c75586afae5e9cea18072d
jacobkahn authored and facebook-github-bot committed May 3, 2022
1 parent 998fbfe commit 55f36a7
Showing 9 changed files with 76 additions and 111 deletions.
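The diffs below apply one consistent set of substitutions from the ArrayFire API to the fl::Tensor API. A small standalone sketch of the recurring replacements, using only calls that appear in this diff (the header paths and `main` wrapper are assumptions for illustration; the exact includes vary by file):

```cpp
// Recurring ArrayFire -> fl::Tensor substitutions seen in this commit.
#include <vector>

#include "flashlight/fl/tensor/Init.h"
#include "flashlight/fl/tensor/Random.h"
#include "flashlight/fl/tensor/TensorBase.h"

int main() {
  fl::init();

  // af::constant(0.0, dims, type)     ->  fl::full(shape, 0., type)
  auto zeros = fl::full({3, 4}, 0., fl::dtype::f32);

  // af::array(vec.size(), vec.data()) ->  fl::Tensor::fromVector(vec)
  std::vector<double> vals = {1.0, 2.0, 3.0};
  auto packed = fl::Tensor::fromVector(vals);

  // afToVector<T>(arr)                ->  tensor.toHostVector<T>()
  auto unpacked = packed.toHostVector<double>();

  // af::randn(d0, d1, ..., f32)       ->  fl::randn(shape, fl::dtype::f32)
  auto noise = fl::randn({5, 5}, fl::dtype::f32);

  // af::dim4(...) in assertions       ->  fl::Shape({...})
  bool sameShape = (noise.shape() == fl::Shape({5, 5}));
  return (zeros.elements() == 12 && sameShape && unpacked.size() == vals.size())
      ? 0
      : 1;
}
```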
6 changes: 3 additions & 3 deletions flashlight/pkg/runtime/amp/DynamicScaler.cpp
@@ -23,7 +23,7 @@ DynamicScaler::DynamicScaler(

 fl::Variable DynamicScaler::scale(const fl::Variable& loss) {
   // Force casting to fp32 to avoid overflow in scaling.
-  auto scaledLoss = loss.as(af::dtype::f32);
+  auto scaledLoss = loss.as(fl::dtype::f32);
   scaledLoss = scaledLoss * scaleFactor_;
   return scaledLoss;
 }
@@ -32,10 +32,10 @@ bool DynamicScaler::unscale(std::vector<fl::Variable>& params) {
   for (auto& p : params) {
     if (!p.isGradAvailable()) {
       // Add a dummy grad for params not used in the backwards pass
-      p.addGrad(Variable(af::constant(0.0, p.dims(), p.type()), false));
+      p.addGrad(Variable(fl::full(p.dims(), 0., p.type()), false));
     }
     p.grad() = p.grad() / scaleFactor_;
-    if (fl::isInvalidArray(p.grad().array())) {
+    if (fl::isInvalidArray(p.grad().tensor())) {
       if (scaleFactor_ >= fl::kAmpMinimumScaleFactorValue) {
         scaleFactor_ = scaleFactor_ / 2.0f;
         FL_LOG(INFO) << "AMP: Scale factor decreased. New value:\t"
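For orientation, the scale/unscale pair above is meant to bracket the backward pass in an AMP training step. A standalone sketch of that sequence, assuming the constructor argument order from the test further below (initFactor, maxFactor, updateInterval), the umbrella `flashlight.h` header, and a dummy Variable standing in for real model parameters:

```cpp
// Minimal scale -> backward -> unscale sketch (illustrative only; a real
// training step would use model parameters and an optimizer).
#include <vector>

#include "flashlight/fl/flashlight.h"
#include "flashlight/pkg/runtime/amp/DynamicScaler.h"

int main() {
  fl::init();
  // Argument order assumed from DynamicScalerTest: initFactor, maxFactor,
  // updateInterval.
  fl::pkg::runtime::DynamicScaler scaler(32, 32768, 100);

  auto w = fl::Variable(fl::rand({10, 10}), /* calcGrad = */ true);
  auto loss = fl::sum(w * w, {0, 1});

  auto scaled = scaler.scale(loss); // casts to f32, multiplies by the factor
  scaled.backward();

  std::vector<fl::Variable> params = {w};
  // Divides grads by the factor; returns false (and shrinks the factor) if
  // any gradient is non-finite, in which case the update should be skipped.
  bool ok = scaler.unscale(params);
  return ok ? 0 : 1;
}
```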
40 changes: 20 additions & 20 deletions flashlight/pkg/runtime/common/DistributedUtils.cpp
@@ -36,70 +36,70 @@ void initDistributed(
   }
 }

-af::array allreduceGet(fl::AverageValueMeter& mtr) {
+Tensor allreduceGet(fl::AverageValueMeter& mtr) {
   auto mtrVal = mtr.value();
   mtrVal[0] *= mtrVal[2];
-  return af::array(mtrVal.size(), mtrVal.data());
+  return Tensor::fromVector(mtrVal);
 }

-af::array allreduceGet(fl::EditDistanceMeter& mtr) {
+Tensor allreduceGet(fl::EditDistanceMeter& mtr) {
   auto mtrVal0 = mtr.value();
   std::vector<long long> mtrVal(mtrVal0.begin(), mtrVal0.end());
-  return af::array(mtrVal.size(), mtrVal.data());
+  return Tensor::fromVector(mtrVal);
 }

-af::array allreduceGet(fl::CountMeter& mtr) {
+Tensor allreduceGet(fl::CountMeter& mtr) {
   auto mtrVal0 = mtr.value();
   std::vector<long long> mtrVal(mtrVal0.begin(), mtrVal0.end());
-  return af::array(mtrVal.size(), mtrVal.data());
+  return Tensor::fromVector(mtrVal);
 }

-af::array allreduceGet(fl::TimeMeter& mtr) {
-  return af::constant(mtr.value(), 1, af::dtype::f64);
+Tensor allreduceGet(fl::TimeMeter& mtr) {
+  return fl::full({1}, mtr.value(), fl::dtype::f64);
 }

-af::array allreduceGet(fl::TopKMeter& mtr) {
+Tensor allreduceGet(fl::TopKMeter& mtr) {
   std::pair<int32_t, int32_t> stats = mtr.getStats();
   std::vector<int32_t> vec = {stats.first, stats.second};
-  return af::array(vec.size(), vec.data());
+  return Tensor::fromVector(vec);
 }

-void allreduceSet(fl::AverageValueMeter& mtr, af::array& val) {
+void allreduceSet(fl::AverageValueMeter& mtr, Tensor& val) {
   mtr.reset();
-  auto valVec = afToVector<double>(val);
+  auto valVec = val.toHostVector<double>();
   if (valVec[2] != 0) {
     valVec[0] /= valVec[2];
   }
   mtr.add(valVec[0], valVec[2]);
 }

-void allreduceSet(fl::EditDistanceMeter& mtr, af::array& val) {
+void allreduceSet(fl::EditDistanceMeter& mtr, Tensor& val) {
   mtr.reset();
-  auto valVec = afToVector<long long>(val);
+  auto valVec = val.toHostVector<long long>();
   mtr.add(
       static_cast<int64_t>(valVec[1]),
       static_cast<int64_t>(valVec[2]),
       static_cast<int64_t>(valVec[3]),
       static_cast<int64_t>(valVec[4]));
 }

-void allreduceSet(fl::CountMeter& mtr, af::array& val) {
+void allreduceSet(fl::CountMeter& mtr, Tensor& val) {
   mtr.reset();
-  auto valVec = afToVector<long long>(val);
+  auto valVec = val.toHostVector<long long>();
   for (size_t i = 0; i < valVec.size(); ++i) {
     mtr.add(i, valVec[i]);
   }
 }

-void allreduceSet(fl::TimeMeter& mtr, af::array& val) {
+void allreduceSet(fl::TimeMeter& mtr, Tensor& val) {
   auto worldSize = fl::getWorldSize();
-  auto valVec = afToVector<double>(val);
+  auto valVec = val.toHostVector<double>();
   mtr.set(valVec[0] / worldSize);
 }

-void allreduceSet(fl::TopKMeter& mtr, af::array& val) {
+void allreduceSet(fl::TopKMeter& mtr, Tensor& val) {
   mtr.reset();
-  auto valVec = afToVector<int32_t>(val);
+  auto valVec = val.toHostVector<int32_t>();
   mtr.set(valVec[0], valVec[1]);
 }
 } // namespace runtime
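Each helper above packs a meter's statistics into a flat one-dimensional Tensor on the way out and unpacks a host vector on the way back. A standalone sketch of just that packing round trip, using the Tensor calls from this file (the sample values and `main` wrapper are illustrative; the real code all-reduces the tensor between the two steps):

```cpp
// Pack/unpack round trip used by the allreduceGet/allreduceSet helpers above.
#include <vector>

#include "flashlight/fl/tensor/Init.h"
#include "flashlight/fl/tensor/TensorBase.h"

int main() {
  fl::init();
  std::vector<double> stats = {2.5, 0.0, 4.0};   // e.g. {value, unused, count}
  auto packed = fl::Tensor::fromVector(stats);   // host vector -> 1-D Tensor
  // ... fl::allReduce(packed) would run here in the distributed case ...
  auto unpacked = packed.toHostVector<double>(); // Tensor -> host vector
  return unpacked.size() == stats.size() ? 0 : 1;
}
```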
29 changes: 15 additions & 14 deletions flashlight/pkg/runtime/common/DistributedUtils.h
@@ -10,9 +10,12 @@
 #include <string>
 #include <vector>

-#include "flashlight/fl/flashlight.h"
+#include "flashlight/fl/distributed/DistributedApi.h"
+#include "flashlight/fl/meter/meters.h"

 namespace fl {
+class Tensor;
+
 namespace pkg {
 namespace runtime {

@@ -25,17 +28,17 @@ void initDistributed(
     int maxDevicesPerNode,
     const std::string& rndvFilepath);

-af::array allreduceGet(AverageValueMeter& mtr);
-af::array allreduceGet(EditDistanceMeter& mtr);
-af::array allreduceGet(CountMeter& mtr);
-af::array allreduceGet(TimeMeter& mtr);
-af::array allreduceGet(TopKMeter& mtr);
+Tensor allreduceGet(AverageValueMeter& mtr);
+Tensor allreduceGet(EditDistanceMeter& mtr);
+Tensor allreduceGet(CountMeter& mtr);
+Tensor allreduceGet(TimeMeter& mtr);
+Tensor allreduceGet(TopKMeter& mtr);

-void allreduceSet(AverageValueMeter& mtr, af::array& val);
-void allreduceSet(EditDistanceMeter& mtr, af::array& val);
-void allreduceSet(CountMeter& mtr, af::array& val);
-void allreduceSet(TimeMeter& mtr, af::array& val);
-void allreduceSet(TopKMeter& mtr, af::array& val);
+void allreduceSet(AverageValueMeter& mtr, Tensor& val);
+void allreduceSet(EditDistanceMeter& mtr, Tensor& val);
+void allreduceSet(CountMeter& mtr, Tensor& val);
+void allreduceSet(TimeMeter& mtr, Tensor& val);
+void allreduceSet(TopKMeter& mtr, Tensor& val);

 /**
  * Synchronize meters across process.
@@ -45,7 +48,7 @@ void syncMeter(T& mtr) {
   if (!fl::isDistributedInit()) {
     return;
   }
-  af::array arr = allreduceGet(mtr);
+  Tensor arr = allreduceGet(mtr);
   fl::allReduce(arr);
   allreduceSet(mtr, arr);
 }
@@ -59,5 +62,3 @@ template void syncMeter<TopKMeter>(TopKMeter& mtr);
 } // namespace runtime
 } // namespace pkg
 } // namespace fl
-
-#include "flashlight/pkg/runtime/common/Utils-inl.h"
41 changes: 22 additions & 19 deletions flashlight/pkg/runtime/common/SequentialBuilder.cpp
@@ -9,6 +9,7 @@

 #include <stdexcept>

+#include "flashlight/fl/tensor/Types.h"
 #include "flashlight/lib/common/String.h"
 #include "flashlight/lib/common/System.h"

@@ -62,13 +63,13 @@ std::shared_ptr<Sequential> buildSequentialModule(
 fl::Variable forwardSequentialModuleWithPadMask(
     const fl::Variable& input,
     std::shared_ptr<fl::Module> ntwrk,
-    const af::array& inputSizes) {
+    const Tensor& inputSizes) {
   // expected input dims T x C x 1 x B
   int T = input.dims(0), B = input.dims(3);
-  auto inputMaxSize = af::tile(af::max(inputSizes), 1, B);
-  af::array inputNotPaddedSize = af::ceil(inputSizes * T / inputMaxSize);
-  auto padMask = af::iota(af::dim4(T, 1), af::dim4(1, B)) <
-      af::tile(inputNotPaddedSize, T, 1);
+  auto inputMaxSize = fl::tile(fl::amax(inputSizes, {1}), {1, B});
+  Tensor inputNotPaddedSize = fl::ceil(inputSizes * T / inputMaxSize);
+  auto padMask =
+      fl::iota({T, 1}, {1, B}) < fl::tile(inputNotPaddedSize, {T, 1});
   auto ntwrkSeq = std::dynamic_pointer_cast<fl::Sequential>(ntwrk);
   auto output = input;
   for (auto& module : ntwrkSeq->modules()) {
@@ -107,17 +108,17 @@ std::shared_ptr<Module> parseLines(
   /* ========== TRANSFORMATIONS ========== */

   if ((params[0] == "RO") || (params[0] == "V")) {
-    if (params.size() != 5) {
+    if (params.size() < 2) {
       throw std::invalid_argument("Failed parsing - " + line);
     }
-    int dim1 = std::stoi(params[1]);
-    int dim2 = std::stoi(params[2]);
-    int dim3 = std::stoi(params[3]);
-    int dim4 = std::stoi(params[4]);
+    Shape shape(std::vector<Dim>(params.size() - 1));
+    for (unsigned i = 1; i < params.size(); ++i) {
+      shape[i - 1] = std::stoi(params[i]);
+    }
     if (params[0] == "RO") {
-      return std::make_shared<Reorder>(dim1, dim2, dim3, dim4);
+      return std::make_shared<Reorder>(shape);
     } else {
-      return std::make_shared<View>(af::dim4(dim1, dim2, dim3, dim4));
+      return std::make_shared<View>(shape);
     }
   }

@@ -127,11 +128,13 @@
     }
     auto val = std::stod(params[1]);
     params.resize(10, "0");
-    std::pair<int, int> pad0 = {std::stoi(params[2]), std::stoi(params[3])};
-    std::pair<int, int> pad1 = {std::stoi(params[4]), std::stoi(params[5])};
-    std::pair<int, int> pad2 = {std::stoi(params[6]), std::stoi(params[7])};
-    std::pair<int, int> pad3 = {std::stoi(params[8]), std::stoi(params[9])};
-    return std::make_shared<Padding>(pad0, pad1, pad2, pad3, val);
+    std::vector<std::pair<int, int>> paddings = {
+        {std::stoi(params[2]), std::stoi(params[3])},
+        {std::stoi(params[4]), std::stoi(params[5])},
+        {std::stoi(params[6]), std::stoi(params[7])},
+        {std::stoi(params[8]), std::stoi(params[9])}};
+    // TODO{fl::Tensor} -- rearrange arguments
+    return std::make_shared<Padding>(paddings, val);
   }

   /* ========== TRANSFORMERS ========== */
@@ -582,12 +585,12 @@ std::shared_ptr<Module> parseLines(
     if (params.size() != 2) {
       throw std::invalid_argument("Failed parsing - " + line);
     }
-    auto targetType = fl::stringToAfType(params[1]);
+    auto targetType = fl::stringToDtype(params[1]);
     return std::make_shared<PrecisionCast>(targetType);
   }

   throw std::invalid_argument("Failed parsing - " + line);
   return nullptr;
 }
 } // namespace

 } // namespace
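The RO/V change above lets Reorder and View take an arbitrary-rank Shape parsed from the remaining tokens instead of exactly four dims. A standalone sketch that mirrors the parsing loop (the sample arch-file tokens and `main` wrapper are hypothetical):

```cpp
// How the RO / V branch above turns arch-file tokens into a Shape.
#include <string>
#include <vector>

#include "flashlight/fl/tensor/Shape.h"

int main() {
  std::vector<std::string> params = {"V", "-1", "1", "40", "0"}; // e.g. "V -1 1 40 0"
  fl::Shape shape(std::vector<fl::Dim>(params.size() - 1));
  for (unsigned i = 1; i < params.size(); ++i) {
    shape[i - 1] = std::stoi(params[i]);
  }
  // shape is now {-1, 1, 40, 0}; with fewer tokens it would simply be shorter.
  return shape.ndim() == 4 ? 0 : 1;
}
```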
6 changes: 4 additions & 2 deletions flashlight/pkg/runtime/common/SequentialBuilder.h
@@ -9,9 +9,11 @@

 #include "flashlight/fl/contrib/contrib.h"
 #include "flashlight/fl/contrib/modules/modules.h"
-#include "flashlight/fl/flashlight.h"

 namespace fl {
+
+class Tensor;
+
 namespace pkg {
 namespace runtime {

@@ -35,7 +37,7 @@ std::shared_ptr<fl::Sequential> buildSequentialModule(
 fl::Variable forwardSequentialModuleWithPadMask(
     const fl::Variable& input,
     std::shared_ptr<fl::Module> ntwrk,
-    const af::array& inputSizes);
+    const Tensor& inputSizes);

 } // namespace runtime
 } // namespace pkg
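Both headers in this commit drop the monolithic flashlight.h include and forward-declare fl::Tensor instead, since their declarations only mention Tensor by reference. A minimal sketch of that pattern (not a file from this commit):

```cpp
// Header sketch: a forward declaration suffices when Tensor only appears by
// reference or pointer in signatures; no tensor headers are needed here.
namespace fl {
class Tensor; // forward declaration

void consume(const Tensor& t); // declared here, defined in a .cpp that
                               // includes the full Tensor definition
} // namespace fl
```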
41 changes: 0 additions & 41 deletions flashlight/pkg/runtime/common/Utils-inl.h

This file was deleted.

2 changes: 1 addition & 1 deletion flashlight/pkg/runtime/test/DynamicScalerTest.cpp
@@ -22,7 +22,7 @@ TEST(DynamicScalerTest, Scaling) {
       100 // updateInterval
   );

-  auto loss = fl::uniform(af::dim4(5, 5, 5, 5));
+  auto loss = fl::uniform({5, 5, 5, 5});

   auto scaledLoss = dynamicScaler.scale(loss);
   ASSERT_TRUE(allClose(loss * 32, scaledLoss));
15 changes: 8 additions & 7 deletions flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp
@@ -7,9 +7,10 @@

 #include <gtest/gtest.h>

-#include "flashlight/pkg/runtime/common/SequentialBuilder.h"
+#include "flashlight/fl/tensor/Init.h"
+#include "flashlight/fl/tensor/Random.h"
 #include "flashlight/lib/common/System.h"
+#include "flashlight/pkg/runtime/common/SequentialBuilder.h"

 using namespace fl;
 using namespace fl::pkg::runtime;
@@ -33,16 +34,16 @@ TEST(SequentialBuilderTest, SeqModule) {

   auto model = buildSequentialModule(archfile, nchannel, nclass);

-  auto input = af::randn(inputsteps, 1, nchannel, batchsize, f32);
+  auto input = fl::randn({inputsteps, 1, nchannel, batchsize}, fl::dtype::f32);

   auto output = model->forward(noGrad(input));

-  ASSERT_EQ(output.dims(), af::dim4(nclass, inputsteps, batchsize));
+  ASSERT_EQ(output.dims(), Shape({nclass, inputsteps, batchsize}));

   batchsize = 1;
-  input = af::randn(inputsteps, 1, nchannel, batchsize, f32);
+  input = fl::randn({inputsteps, 1, nchannel, batchsize}, fl::dtype::f32);
   output = model->forward(noGrad(input));
-  ASSERT_EQ(output.dims(), af::dim4(nclass, inputsteps, batchsize));
+  ASSERT_EQ(output.dims(), Shape({nclass, inputsteps, batchsize}));
 }

 TEST(SequentialBuilderTest, Serialization) {
@@ -60,7 +61,7 @@ TEST(SequentialBuilderTest, Serialization) {
   int C = 1, N = 5, B = 1, T = 10;
   auto model = buildSequentialModule(archfile, C, N);

-  auto input = noGrad(af::randn(T, 1, C, B, f32));
+  auto input = noGrad(fl::randn({T, 1, C, B}, fl::dtype::f32));
   auto output = model->forward(input);

   save(path, model);
@@ -71,7 +72,7 @@ TEST(SequentialBuilderTest, Serialization) {
   auto outputl = loaded->forward(input);

   ASSERT_TRUE(allParamsClose(*loaded.get(), *model));
-  ASSERT_TRUE(allClose(outputl, output));
+  ASSERT_TRUE(allClose(outputl.tensor(), output.tensor()));
 }

 int main(int argc, char** argv) {