Skip to content

Commit

Permalink
Change pkg/speech/runtime to use fl::Tensor (flashlight#794)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: flashlight#794

Migrate `pkg/speech/runtime` to use `fl::Tensor`.

Reviewed By: benoitsteiner

Differential Revision: D31984596

fbshipit-source-id: cc38e3dcabdabafe63c1e0bc2c1e9fa6593421db
  • Loading branch information
jacobkahn authored and facebook-github-bot committed May 3, 2022
1 parent 16de4f4 commit fce237c
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 53 deletions.
3 changes: 2 additions & 1 deletion flashlight/pkg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ if(FL_BUILD_PKG_TEXT)
endif()

# --------------------------- Speech ---------------------------
# TODO(jacobkahn): remove the dependency on the runtime package
fl_dependent_option(
FL_BUILD_PKG_SPEECH
"Build speech package for Flashlight"
"${FL_BUILD_ALL_PKGS}"
"FL_BUILD_CORE;FL_BUILD_CONTRIB;FL_BUILD_LIB_COMMON;FL_BUILD_LIB_AUDIO;FL_BUILD_LIB_TEXT;FL_BUILD_LIB_SEQUENCE"
"FL_BUILD_CORE;FL_BUILD_CONTRIB;FL_BUILD_LIB_COMMON;FL_BUILD_LIB_AUDIO;FL_BUILD_LIB_TEXT;FL_BUILD_LIB_SEQUENCE;FL_BUILD_PKG_RUNTIME"
OFF)

if(FL_BUILD_PKG_SPEECH)
Expand Down
2 changes: 1 addition & 1 deletion flashlight/pkg/runtime/test/common/arch.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ L 16 32
RO 2 1 0 3
RO 2 3 0 1
GRU 32 256 3 1
RO 0 2 1 3
RO 0 2 1
L 512 NLABEL
33 changes: 16 additions & 17 deletions flashlight/pkg/speech/runtime/Helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "deeplearning/projects/flashlight/fb/EverstoreDataset.h"
#endif

using fl::pkg::runtime::afToVector;
using fl::lib::format;
using fl::lib::getCurrentDate;
using fl::lib::getCurrentTime;
Expand All @@ -35,11 +34,11 @@ namespace pkg {
namespace speech {

template <class T>
std::vector<std::string> afMatrixToStrings(const af::array& arr, T terminator) {
int L = arr.dims(0); // padded length of string
int N = arr.dims(1); // number of strings
std::vector<std::string> afMatrixToStrings(const Tensor& tensor, T terminator) {
int L = tensor.dim(0); // padded length of string
int N = tensor.dim(1); // number of strings
std::vector<std::string> result;
auto values = afToVector<T>(arr);
auto values = tensor.toHostVector<T>();
for (int i = 0; i < N; ++i) {
const T* row = &values[i * L];
int len = 0;
Expand Down Expand Up @@ -96,8 +95,8 @@ getTrainEvalIds(int64_t dsSize, double pctTrainEval, int64_t seed) {
return result;
}

std::vector<std::string> readSampleIds(const af::array& arr) {
return afMatrixToStrings<char>(arr, '\0');
std::vector<std::string> readSampleIds(const Tensor& tensor) {
return afMatrixToStrings<char>(tensor, '\0');
}

std::shared_ptr<fl::Dataset> createDataset(
Expand Down Expand Up @@ -171,19 +170,19 @@ std::shared_ptr<fl::Dataset> createDataset(
int inPad, tgtPad, wrdPad;
std::tie(inPad, tgtPad, wrdPad) = padVal;
auto batchFns = std::vector<fl::Dataset::BatchFunction>{
[inPad](const std::vector<af::array>& arr) {
return fl::join(arr, inPad, 3);
[inPad](const std::vector<Tensor>& tensor) {
return fl::join(tensor, inPad, 3);
},
[tgtPad](const std::vector<af::array>& arr) {
return fl::join(arr, tgtPad, 1);
[tgtPad](const std::vector<Tensor>& tensor) {
return fl::join(tensor, tgtPad, 1);
},
[wrdPad](const std::vector<af::array>& arr) {
return fl::join(arr, wrdPad, 1);
[wrdPad](const std::vector<Tensor>& tensor) {
return fl::join(tensor, wrdPad, 1);
},
[](const std::vector<af::array>& arr) { return fl::join(arr, 0, 1); },
[](const std::vector<af::array>& arr) { return fl::join(arr, 0, 1); },
[](const std::vector<af::array>& arr) { return fl::join(arr, 0, 1); },
[](const std::vector<af::array>& arr) { return fl::join(arr, 0, 1); }};
[](const std::vector<Tensor>& tensor) { return fl::join(tensor, 0, 1); },
[](const std::vector<Tensor>& tensor) { return fl::join(tensor, 0, 1); },
[](const std::vector<Tensor>& tensor) { return fl::join(tensor, 0, 1); },
[](const std::vector<Tensor>& tensor) { return fl::join(tensor, 0, 1); }};
if (batchingStrategy == kBatchStrategyDynamic ||
batchingStrategy == kBatchStrategyRandDynamic) {
// Partition the dataset and distribute
Expand Down
4 changes: 2 additions & 2 deletions flashlight/pkg/speech/runtime/Helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ std::unordered_set<int64_t>
getTrainEvalIds(int64_t dsSize, double pctTrainEval, int64_t seed);

/**
* Read sample ids from an `af::array`.
* Read sample ids from a `Tensor`.
*/
std::vector<std::string> readSampleIds(const af::array& arr);
std::vector<std::string> readSampleIds(const Tensor& arr);

/*
* Utility function for creating a w2l dataset.
Expand Down
15 changes: 7 additions & 8 deletions flashlight/pkg/speech/runtime/Logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

#include <thread>

#include "flashlight/pkg/speech/common/Defines.h"
#include "flashlight/pkg/speech/common/Flags.h"
#include "flashlight/pkg/runtime/common/DistributedUtils.h"
#include "flashlight/lib/common/String.h"
#include "flashlight/lib/common/System.h"
#include "flashlight/pkg/runtime/common/DistributedUtils.h"
#include "flashlight/pkg/speech/common/Defines.h"
#include "flashlight/pkg/speech/common/Flags.h"

using fl::pkg::runtime::afToVector;
using fl::lib::format;
using fl::lib::getCurrentDate;
using fl::lib::getCurrentTime;
Expand Down Expand Up @@ -113,19 +112,19 @@ void appendToLog(std::ofstream& logfile, const std::string& logstr) {
retryWithBackoff(std::chrono::seconds(1), 1.0, 6, write);
}

af::array allreduceGet(SpeechStatMeter& mtr) {
Tensor allreduceGet(SpeechStatMeter& mtr) {
auto mtrValRaw = mtr.value();
std::vector<long long> mtrVal(mtrValRaw.begin(), mtrValRaw.end());
// Caveat: maxInputSz_, maxTargetSz_ would be approximate
mtrVal[2] *= mtrVal[4];
mtrVal[3] *= mtrVal[4];
return af::array(mtrVal.size(), mtrVal.data());
return Tensor::fromVector(mtrVal);
}

void allreduceSet(SpeechStatMeter& mtr, af::array& val) {
void allreduceSet(SpeechStatMeter& mtr, Tensor& val) {
mtr.reset();
// Caveat: maxInputSz_, maxTargetSz_ would be approximate
auto valVec = afToVector<int64_t>(val);
auto valVec = val.toHostVector<int64_t>();
SpeechStats stats;
auto denom = (valVec[4] == 0) ? 1 : valVec[4];
stats.totalInputSz_ = valVec[0];
Expand Down
4 changes: 2 additions & 2 deletions flashlight/pkg/speech/runtime/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ std::string getLogString(

void appendToLog(std::ofstream& logfile, const std::string& logstr);

af::array allreduceGet(SpeechStatMeter& mtr);
void allreduceSet(SpeechStatMeter& mtr, af::array& val);
Tensor allreduceGet(SpeechStatMeter& mtr);
void allreduceSet(SpeechStatMeter& mtr, Tensor& val);

void syncMeter(TrainMeters& mtrs);
} // namespace speech
Expand Down
14 changes: 6 additions & 8 deletions flashlight/pkg/speech/runtime/SpeechStatMeter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,19 @@ void SpeechStatMeter::reset() {
stats_.reset();
}

void SpeechStatMeter::add(
const af::array& inputSizes,
const af::array& targetSizes) {
int64_t curInputSz = af::sum<int64_t>(inputSizes);
int64_t curTargetSz = af::sum<int64_t>(targetSizes);
void SpeechStatMeter::add(const Tensor& inputSizes, const Tensor& targetSizes) {
int64_t curInputSz = fl::sum(inputSizes).asScalar<int64_t>();
int64_t curTargetSz = fl::sum(targetSizes).asScalar<int64_t>();

stats_.totalInputSz_ += curInputSz;
stats_.totalTargetSz_ += curTargetSz;

stats_.maxInputSz_ =
std::max(stats_.maxInputSz_, af::max<int64_t>(inputSizes));
std::max(stats_.maxInputSz_, fl::amax(inputSizes).asScalar<int64_t>());
stats_.maxTargetSz_ =
std::max(stats_.maxTargetSz_, af::max<int64_t>(targetSizes));
std::max(stats_.maxTargetSz_, fl::amax(targetSizes).asScalar<int64_t>());

stats_.numSamples_ += inputSizes.dims(1);
stats_.numSamples_ += inputSizes.dim(1);
stats_.numBatches_++;
}

Expand Down
2 changes: 1 addition & 1 deletion flashlight/pkg/speech/runtime/SpeechStatMeter.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct SpeechStats {
class SpeechStatMeter {
public:
SpeechStatMeter();
void add(const af::array& inputSizes, const af::array& targetSizes);
void add(const Tensor& inputSizes, const Tensor& targetSizes);
void add(const SpeechStats& stats);
std::vector<int64_t> value() const;
void reset();
Expand Down
22 changes: 9 additions & 13 deletions flashlight/pkg/speech/test/runtime/RuntimeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "flashlight/pkg/speech/runtime/runtime.h"
#include "flashlight/pkg/runtime/common/Serializer.h"
#include "flashlight/fl/flashlight.h"
#include "flashlight/lib/common/System.h"
#include "flashlight/pkg/runtime/common/Serializer.h"
#include "flashlight/pkg/speech/runtime/runtime.h"

using namespace fl::pkg::speech;
using fl::Tensor;

namespace {
const std::string kPath = fl::lib::getTmpPath("test.mdl");
Expand All @@ -26,13 +27,7 @@ bool afEqual(const fl::Variable& a, const fl::Variable& b) {
if (a.isCalcGrad() != b.isCalcGrad()) {
return false;
}
if (a.dims() != b.dims()) {
return false;
}
if (a.array().isempty() && b.array().isempty()) {
return true;
}
return af::allTrue<bool>(af::abs(a.array() - b.array()) < 1E-7);
return allClose(a.tensor(), b.tensor(), 1E-7);
}

} // namespace
Expand Down Expand Up @@ -65,7 +60,7 @@ TEST(RuntimeTest, LoadAndSave) {
modelload.eval();

for (int i = 0; i < 10; ++i) {
auto in = fl::Variable(af::randu(10, 1, 4), i & 1);
auto in = fl::Variable(fl::rand({10, 1, 4, 1}), i & 1);
ASSERT_TRUE(afEqual(model.forward(in), modelload.forward(in)));
}
}
Expand All @@ -86,8 +81,8 @@ TEST(RuntimeTest, SpeechStatMeter) {
std::array<int, 4> inpSizes2{2, 4, 2, 8};
std::array<int, 4> tgSizes2{3, 7, 2, 4};
meter.add(
af::array(1, 2, inpSizes1.data()), af::array(1, 2, tgSizes1.data()));
af::array out;
Tensor::fromArray({1, 2}, inpSizes1),
Tensor::fromArray({1, 2}, tgSizes1));
auto stats1 = meter.value();
ASSERT_EQ(stats1[0], 9.0);
ASSERT_EQ(stats1[1], 16.0);
Expand All @@ -96,7 +91,8 @@ TEST(RuntimeTest, SpeechStatMeter) {
ASSERT_EQ(stats1[4], 2.0);
ASSERT_EQ(stats1[5], 1);
meter.add(
af::array(1, 4, inpSizes2.data()), af::array(1, 4, tgSizes2.data()));
Tensor::fromArray({1, 4}, inpSizes2),
Tensor::fromArray({1, 4}, tgSizes2));
auto stats2 = meter.value();
ASSERT_EQ(stats2[0], 25.0);
ASSERT_EQ(stats2[1], 32.0);
Expand Down

0 comments on commit fce237c

Please sign in to comment.