Skip to content

Commit

Permalink
attach eviction filling logic to set_cache (pytorch#3034)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3034

X-link: facebookresearch/FBGEMM#132

- add eviction callback into cachelib
- add eviction handling logic in cachelib wrapper and kv db tbe

Reviewed By: q10

Differential Revision: D61200308

fbshipit-source-id: 879dbced0248461949f36a53e649817879e921c7
  • Loading branch information
duduyi2013 authored and facebook-github-bot committed Sep 3, 2024
1 parent 886e5db commit 8346a7d
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ATen/ATen.h>
#include <cachelib/allocator/CacheAllocator.h>
#include <cachelib/facebook/admin/CacheAdmin.h>
#include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h"

#include <cstdint>
#include <iostream>
Expand Down Expand Up @@ -49,8 +50,33 @@ class CacheLibCache {
}

std::unique_ptr<Cache> initializeCacheLib(const CacheConfig& config) {
// Remove callback registered with CacheLib. For each evicted item it copies
// the item's key and its payload into the next free row of the eviction
// tensors, which callers read via get_evicted_indices_and_weights().
auto eviction_cb =
    [this](const facebook::cachelib::LruAllocator::RemoveCbData& data) {
      // Only evictions are recorded. Test the context before touching the
      // eviction tensors so that other kinds of removal never dereference
      // them (they may not have been allocated yet).
      if (data.context != facebook::cachelib::RemoveContext::kEviction) {
        return;
      }
      // The tensors only exist after init_tensor_for_l2_eviction(); an
      // eviction that fires earlier has nowhere to be recorded.
      if (evicted_indices_ptr_ == nullptr || evicted_weights_ptr_ == nullptr) {
        return;
      }
      FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
          evicted_weights_ptr_->scalar_type(), "l2_eviction_handling", [&] {
            // Claim a row; post-increment on the atomic counter hands out a
            // distinct slot to every invocation.
            const auto row_id = eviction_row_id++;
            // Both tensors were sized for a fixed number of rows. Drop the
            // record instead of writing past the end of the buffers.
            if (row_id >= evicted_indices_ptr_->numel()) {
              return;
            }
            auto indices_data_ptr = evicted_indices_ptr_->data_ptr<int64_t>();
            auto weights_data_ptr = evicted_weights_ptr_->data_ptr<scalar_t>();
            const auto weight_dim = evicted_weights_ptr_->size(1);

            // The key bytes are read back as a single int64 index.
            const auto key_ptr = reinterpret_cast<const int64_t*>(
                data.item.getKey().data());
            indices_data_ptr[row_id] = *key_ptr;

            // The item's memory is read as one row of weight_dim elements.
            const auto src_begin =
                reinterpret_cast<const scalar_t*>(data.item.getMemory());
            std::copy(
                src_begin,
                src_begin + weight_dim,
                &weights_data_ptr[row_id * weight_dim]); // dst_start
          });
    };
Cache::Config cacheLibConfig;
cacheLibConfig.setCacheSize(static_cast<uint64_t>(config.cacheSizeBytes))
.setRemoveCallback(eviction_cb)
.setCacheName("TBEL2Cache")
.setAccessConfig({25 /* bucket power */, 10 /* lock power */})
.setFullCoredump(false)
Expand Down Expand Up @@ -132,11 +158,62 @@ class CacheLibCache {
return true;
}

/// Allocate the tensors that receive L2-evicted rows, replacing any eviction
/// tensors allocated by an earlier call:
///   - an index tensor with <count> elements, pre-filled with the sentinel -1
///   - an uninitialized weight tensor of shape (<count>, weights.size(1))
/// Each tensor takes its dtype and device from the matching input tensor.
///
/// @param indices Tensor whose dtype and device are reused for the evicted
/// index tensor
/// @param weights The 2D tensor whose dtype, device and row width
/// (weights.size(1)) are reused for the evicted weight tensor
/// @param count A single element tensor that contains the number of indices
/// to be processed; it determines how many rows are allocated
///
/// @return None
void init_tensor_for_l2_eviction(
    const at::Tensor& indices,
    const at::Tensor& weights,
    const at::Tensor& count) {
  // int64_t rather than long: the width of long differs across platforms.
  const auto num_lookups = count.item<int64_t>();
  // Fill with the sentinel directly instead of building a ones tensor and
  // multiplying it by -1, which allocates a throw-away temporary.
  evicted_indices_ptr_ = std::make_shared<at::Tensor>(at::full(
      {num_lookups},
      -1,
      at::TensorOptions().device(indices.device()).dtype(indices.dtype())));
  evicted_weights_ptr_ = std::make_shared<at::Tensor>(at::empty(
      {num_lookups, weights.size(1)},
      at::TensorOptions().device(weights.device()).dtype(weights.dtype())));
}

/// Rewind the eviction write cursor to row 0, so the next eviction that gets
/// recorded is written to the first slot of the eviction tensors.
void reset_eviction_states() {
  eviction_row_id.store(0);
}

/// Return the (indices, weights) tensors that collect L2 evictions.
///
/// @return folly::none when the eviction tensors have not been allocated;
/// otherwise the pair of tensors, in which every row may still be invalid
/// if no eviction happened
folly::Optional<std::pair<at::Tensor, at::Tensor>>
get_evicted_indices_and_weights() {
  if (!evicted_indices_ptr_) {
    return folly::none;
  }
  // Both tensors are allocated together, so the weights must exist as well.
  assert(evicted_weights_ptr_ != nullptr);
  return std::make_pair(*evicted_indices_ptr_, *evicted_weights_ptr_);
}

private:
const CacheConfig cacheConfig_;
std::unique_ptr<Cache> cache_;
std::vector<facebook::cachelib::PoolId> pool_ids_;
std::unique_ptr<facebook::cachelib::CacheAdmin> admin_;

// Keys of evicted items, written by the eviction callback; allocated and
// pre-filled with -1 by init_tensor_for_l2_eviction(), null until then.
std::shared_ptr<at::Tensor> evicted_indices_ptr_{nullptr};
// Payload rows of evicted items, row-aligned with evicted_indices_ptr_;
// allocated by init_tensor_for_l2_eviction(), null until then.
std::shared_ptr<at::Tensor> evicted_weights_ptr_{nullptr};
// Next free row in the eviction tensors; post-incremented by the eviction
// callback and set back to 0 by reset_eviction_states().
std::atomic<int64_t> eviction_row_id{0};
};

} // namespace l2_cache
Original file line number Diff line number Diff line change
Expand Up @@ -279,16 +279,16 @@ void EmbeddingKVDB::wait_util_filling_work_done() {
facebook::WallClockUtil::NowInUsecFast() - start_ts;
}

void EmbeddingKVDB::set_cache(
folly::Optional<std::pair<at::Tensor, at::Tensor>> EmbeddingKVDB::set_cache(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count) {
// caller's responsibility to make sure l2_cache_ exists

// TODO: consider whether need to reconstruct indices/weights/count and free
// the original tensor since most of the tensor elem will be invalid,
// this will trade some perf for peak DRAM util saving
if (l2_cache_ == nullptr) {
return;
}
l2_cache_->init_tensor_for_l2_eviction(indices, weights, count);
auto indices_addr = indices.data_ptr<int64_t>();
auto num_lookups = count.item<long>();
for (auto i = 0; i < num_lookups; i++) {
Expand All @@ -300,6 +300,8 @@ void EmbeddingKVDB::set_cache(
<< "]Failed to insert into cache, this shouldn't happen";
}
}
l2_cache_->reset_eviction_states();
return l2_cache_->get_evicted_indices_and_weights();
}

folly::coro::Task<void> EmbeddingKVDB::cache_memcpy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,10 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
/// @param count A single element tensor that contains the number of indices
/// to be processed
///
/// @return pair of tensors with length of <count> containing L2 evicted
/// embedding indices and embeddings, invalid pairs will have
/// sentinel value(-1) on <indices>
void set_cache(
/// @return None if L2 is missing, otherwise return pair of tensors with
/// length of <count> containing L2 evicted embedding indices and embeddings,
/// invalid pairs will have sentinel value(-1) on <indices>
folly::Optional<std::pair<at::Tensor, at::Tensor>> set_cache(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count);
Expand Down

0 comments on commit 8346a7d

Please sign in to comment.