rowwise sparse adagrad fused (pytorch#298)
Summary:
Pull Request resolved: pytorch#298

Implements a JIT kernel used by the fused rowwise sparse AdaGrad operator.
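
For context, a minimal scalar sketch of the update the JIT kernel performs, based on a reading of the reference semantics (the function name and signature here are illustrative, not the actual FBGEMM API):

#include <cmath>
#include <cstdint>

// One AdaGrad momentum value per embedding row. Each batch element's
// gradient row is applied to every table row it looked up (SLS-style
// fusion of the backward pass with the optimizer step).
void rowwise_sparse_adagrad_fused_sketch(
    int batch_size,
    int block_size, // embedding dimension
    float* w, // [num_rows][block_size] parameters, updated in place
    float* h, // [num_rows] rowwise momentum, updated in place
    const float* g, // [batch_size][block_size] output gradients
    const std::int64_t* indices,
    const int* lengths, // lengths[b] indices belong to batch element b
    float epsilon,
    float lr) {
  std::int64_t cur = 0;
  for (int b = 0; b < batch_size; ++b) {
    const float* g_row = g + b * block_size;
    for (int l = 0; l < lengths[b]; ++l, ++cur) {
      float* w_row = w + indices[cur] * block_size;
      float sum_sq = 0.0f;
      for (int j = 0; j < block_size; ++j) {
        sum_sq += g_row[j] * g_row[j];
      }
      h[indices[cur]] += sum_sq / block_size; // rowwise: mean of squares
      float step = lr / (std::sqrt(h[indices[cur]]) + epsilon);
      for (int j = 0; j < block_size; ++j) {
        w_row[j] += step * g_row[j];
      }
    }
  }
}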

A few more minor changes:
* Incorporate bandwidth (BW) usage for data other than the embedding table in the benchmarks.
* Add error checking for when the number of indices does not match the sum of lengths.

Reviewed By: jianyuh

Differential Revision: D19919020

fbshipit-source-id: 078072f6f643112c8758ac9a562c47fbe659478f
jspark1105 committed Mar 21, 2020
1 parent ccd91d6 commit cfda552
Showing 14 changed files with 1,235 additions and 145 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -56,6 +56,7 @@ set(FBGEMM_GENERIC_SRCS src/EmbeddingSpMDM.cc
src/PackWeightsForConv.cc
src/QuantUtils.cc
src/RefImplementations.cc
src/RowWiseSparseAdagradFused.cc
src/SparseAdagrad.cc
src/Utils.cc)

17 changes: 9 additions & 8 deletions bench/EmbeddingSpMDM8BitBenchmark.cc
@@ -131,14 +131,15 @@ int run_benchmark(

constexpr int NUM_WARMUP = 4;
int NUM_ITER = stress_multi_threading ? 1 << 20 : 10;
// Only counts the number of bytes for reading the embedding table and
// ignores others. Should be good enough as long as embedding_dim is big enough.
double bytes =
lengths_sum * (embedding_dim * sizeof(uint8_t) + 2 * sizeof(float));
double bytes_padded =
lengths_sum * 64 *
static_cast<int>(
(embedding_dim * sizeof(uint8_t) + 2 * sizeof(float) + 63) / 64);
double bytes = lengths_sum *
(embedding_dim * sizeof(uint8_t) + 2 * sizeof(float) +
(use_32_bit_indices ? 4 : 8)) +
batch_size * sizeof(int);
double bytes_padded = lengths_sum *
((embedding_dim * sizeof(uint8_t) + 2 * sizeof(float) + 63) / 64 *
64 +
(use_32_bit_indices ? 4 : 8)) +
batch_size * sizeof(int);
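
// A worked instance of the new `bytes` accounting above (illustrative
// numbers): for an int8 table, each looked-up row moves its payload plus the
// per-row scale/bias floats plus one index; each batch element also reads
// one length. E.g. embedding_dim = 128, 32-bit indices, lengths_sum = 1000,
// batch_size = 10:
//   bytes = 1000 * (128 * 1 + 2 * 4 + 4) + 10 * 4
//         = 1000 * 140 + 40
//         = 140040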

vector<bool> has_weight_options;
has_weight_options.push_back(false);
31 changes: 12 additions & 19 deletions bench/EmbeddingSpMDMBenchmark.cc
@@ -45,7 +45,7 @@ static vector<vector<int>> GetInputs_() {
return input_dims;
}

int run_benchmark(
void run_benchmark(
int batch_size,
int num_rows,
int embedding_dim,
@@ -112,14 +112,13 @@ int run_benchmark(

constexpr int NUM_WARMUP = 4;
constexpr int NUM_ITER = 10;
// Only counts the number of bytes for reading the embedding table and
// ignores others. Should be good enough as long as embedding_dim is big enough.
double bytes =
lengths_sum * (embedding_dim * sizeof(uint8_t) + 2 * sizeof(float));
double bytes_padded =
lengths_sum * 64 *
static_cast<int>(
(embedding_dim * sizeof(uint8_t) + 2 * sizeof(float) + 63) / 64);
double bytes = lengths_sum *
(embedding_dim * sizeof(float) + (use_32_bit_indices ? 4 : 8)) +
batch_size * sizeof(int);
double bytes_padded = lengths_sum *
((embedding_dim * sizeof(float) + 63) / 64 * 64 +
(use_32_bit_indices ? 4 : 8)) +
batch_size * sizeof(int);
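
// A worked instance of the padded accounting above (illustrative numbers):
// each fp32 row is rounded up to whole 64-byte cache lines, which is what
// the memory system actually transfers. With embedding_dim = 100:
//   row bytes = 100 * 4 = 400
//   padded = (400 + 63) / 64 * 64 = 7 * 64 = 448 bytes (7 cache lines)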

for (bool has_weight : {false, true}) {
vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
@@ -304,23 +303,17 @@ int run_benchmark(
<< setw(16) << t << endl;
} // flush_cache
} // has_weight
return 0;
}

int main() {
int batch_size;
int num_rows;
int embedding_dim;
int average_len;

vector<vector<int>> inputs(GetInputs_());

for (auto& input : inputs) {
assert(input.size() > 3);
batch_size = input[0];
num_rows = input[1];
embedding_dim = input[2];
average_len = input[3];
int batch_size = input[0];
int num_rows = input[1];
int embedding_dim = input[2];
int average_len = input[3];

cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows"
<< setw(16) << num_rows << setw(10) << "emb dim" << setw(6)
10 changes: 5 additions & 5 deletions bench/RowwiseAdagradBenchmark.cc
@@ -38,14 +38,14 @@ void run_benchmark(
int block_size, // number of parameters per row
std::uint64_t param_size) { // total number of parameters
vector<char> llc(64L * 1024L * 1024L, 1.0);
vector<float> g(param_size); // gradients
vector<float> h(param_size); // input momentums
vector<float> g(num_rows * block_size); // gradients
vector<float> h(param_size / block_size); // input momentums
vector<float> w(param_size); // input params
vector<float> h_ref(param_size);
vector<float> h_ref(param_size / block_size);
vector<float> w_ref(param_size);
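
// How the buffers relate after this change (a reading of the diff):
//   param_size = num_rows * block_size -> w, w_ref: one float per parameter
//   g: num_rows * block_size           -> one gradient per parameter
//   h, h_ref: param_size / block_size  -> one momentum per ROW (rowwise AdaGrad)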

default_random_engine generator;
normal_distribution<float> h_w_distribution;
// normal_distribution<float> h_w_distribution;

// TODO: check appropriate vals for g,h,w
for (int i = 0; i < g.size(); ++i) {
@@ -104,7 +104,7 @@ int main() {
vector<vector<int>> inputs(GetInputs_());

for (auto& input : inputs) {
assert(input.size() > 2);
assert(input.size() >= 2);
num_rows = input[0];
block_size = input[1];
param_size = num_rows * block_size;
206 changes: 206 additions & 0 deletions bench/RowwiseAdagradFusedBenchmark.cc
@@ -0,0 +1,206 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <map>
#include <random>
#include <set>
#include <vector>

#include "./BenchUtils.h"
#include "fbgemm/Fbgemm.h"
#include "src/RefImplementations.h"

using namespace std;
using namespace fbgemm;

static vector<vector<int>> GetInputs_() {
vector<vector<int>> input_dims = {
// batch size, number of rows of table, emb dim, avg length
// TODO: Add more inputs
// Use these -- but they are slow.
{10, 4000000, 32, 100},
{10, 4000000, 64, 100},
{10, 4000000, 128, 100},
{10, 4000000, 256, 100},
// Use these for debugging
// {2, 16, 128, 10},
// {10, 4000, 128, 100},
// {10, 4000, 128, 100},
// {10, 4000, 128, 100},
};
return input_dims;
}

void run_benchmark(
int batch_size,
int num_rows,
int embedding_dim,
int average_len,
bool use_32_bit_indices = false,
bool prefetch = false) {
vector<char> llc(64L * 1024L * 1024L, 1.0);
vector<float> g(batch_size * embedding_dim); // gradients
vector<float> h(num_rows); // input momentums
vector<float> w(num_rows * embedding_dim); // input params
vector<float> h_ref(h.size());
vector<float> w_ref(w.size());

default_random_engine generator;
// normal_distribution<float> h_w_distribution;

// TODO: check appropriate vals for g,h,w
for (int i = 0; i < g.size(); ++i) {
g[i] = 4 + i; // h_w_distribution(generator);
}
for (int i = 0; i < h.size(); ++i) {
h_ref[i] = h[i] = 2 + i; // h_w_distribution(generator);
}
for (int i = 0; i < w.size(); ++i) {
w_ref[i] = w[i] = 3 + i; // h_w_distribution(generator);
}

// Generate lengths
uniform_int_distribution<int> length_distribution(
1, std::min(2 * average_len + 1, num_rows));
vector<int> lengths(batch_size);
for (int i = 0; i < batch_size; ++i) {
lengths[i] = length_distribution(generator);
}

// Compute the number of indices
int lengths_sum = accumulate(lengths.begin(), lengths.end(), 0);
cout << "lengths_sum " << lengths_sum << endl;

// Generate indices
vector<int64_t> indices;
vector<int32_t> indices_32;

vector<int> container(num_rows);

// Note: indices are generated to be unique within each batch element
for (int i = 0; i < batch_size; ++i) {
iota(container.begin(), container.end(), 0);
random_shuffle(container.begin(), container.end());
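// Note: std::random_shuffle is deprecated in C++14 and removed in C++17;
// std::shuffle(container.begin(), container.end(), generator) would be the
// modern equivalent here.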
copy(
container.begin(),
container.begin() + lengths[i],
back_inserter(indices));
}
copy(begin(indices), end(indices), back_inserter(indices_32));

float epsilon = 1e-5;
float lr = 0.5;

constexpr int NUM_WARMUP = 4;
constexpr int NUM_ITER = 10;
// Counts the bytes touched: embedding rows and the rowwise momentum are
// read and written; gradients, indices, and lengths are read once.
double bytes = lengths_sum *
((embedding_dim + 1) * sizeof(float) * 2 +
(use_32_bit_indices ? 4 : 8)) +
batch_size * (embedding_dim * sizeof(float) + sizeof(int));
double bytes_padded = lengths_sum *
(((embedding_dim * sizeof(float) + 63) / 64 + 1) * 64 * 2 +
(use_32_bit_indices ? 4 : 8)) +
batch_size * (embedding_dim * sizeof(float) + sizeof(int));
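// Breakdown of the padded per-index cost above, with an illustrative
// embedding_dim = 128:
//   (128 * 4 + 63) / 64 = 8 cache lines for the weight row,
//   + 1 line for the rowwise momentum value,
//   * 64 bytes * 2 for read + write -> (8 + 1) * 64 * 2 = 1152 bytes.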

auto kernel_i32 = GenerateRowWiseSparseAdaGradFused<int32_t>(
embedding_dim, prefetch ? 16 : 0);
auto kernel_i64 = GenerateRowWiseSparseAdaGradFused<int64_t>(
embedding_dim, prefetch ? 16 : 0);

for (bool flush_cache : {false, true}) {
double t = measureWithWarmup(
[&]() {
if (use_32_bit_indices) {
kernel_i32(
batch_size,
lengths_sum,
num_rows,
w.data(),
g.data(),
h.data(),
indices_32.data(),
lengths.data(),
epsilon,
lr);
} else {
kernel_i64(
batch_size,
lengths_sum,
num_rows,
w.data(),
g.data(),
h.data(),
indices.data(),
lengths.data(),
epsilon,
lr);
}
},
NUM_WARMUP,
NUM_ITER,
[&]() { llc_flush(llc); });

if (flush_cache) {
cout << setw(20) << "cache flushed";
} else {
cout << setw(20) << "cache not flushed";
}
if (prefetch) {
cout << setw(16) << "prefetch on";
} else {
cout << setw(16) << "prefetch off";
}

cout << setw(8) << "b/w" << setw(10) << bytes / 1e9 / t << " GB/s"
<< setw(20) << "effective b/w: " << setw(16) << bytes_padded / 1e9 / t
<< "GB/s" << setw(8) << " time " << setw(16) << t << endl;
}
}

int main() {
vector<vector<int>> inputs(GetInputs_());

for (auto& input : inputs) {
assert(input.size() > 3);
int batch_size = input[0];
int num_rows = input[1];
int embedding_dim = input[2];
int average_len = input[3];

cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows"
<< setw(16) << num_rows << setw(10) << "emb dim" << setw(6)
<< embedding_dim << setw(16) << "avg length" << setw(6) << average_len
<< endl;

for (bool use_32_bit_indices : {false, true}) {
for (bool prefetch : {false, true}) {
// args: batch sz, num rows, emb dim, avg len, use 32b, prefetch
cout << (use_32_bit_indices ? " 32" : " 64") << " bit indices";
if (prefetch) {
cout << " with prefetching";
}
cout << ", ";
run_benchmark(
batch_size,
num_rows,
embedding_dim,
average_len,
use_32_bit_indices,
prefetch);
} // prefetch
} // use_32_bit_indices
} // for each input

return 0;
}