Implement PS KV DB for FBGEMM TBE operator (pytorch#2664)
Summary:
Pull Request resolved: pytorch#2664

Implemented the kv_db::EmbeddingKVDB interface to connect to a remote parameter server (PS) service instead of using the local SSD-based RocksDB embedding store.
The PS-backed KV DB is initialized when ps_hosts is not None.

Reviewed By: sryap

Differential Revision: D56715840

fbshipit-source-id: 53ebb514bceb21ee4c124afed46907875f9e1750
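
As a usage sketch (hypothetical host names and ports; all other constructor arguments are unchanged), the new ps_hosts argument is a tuple of (host, port) pairs that the module splits into the IP and port lists handed to the parameter-server wrapper:

# Hypothetical hosts/ports; this only illustrates the expected shape of ps_hosts.
ps_hosts = (
    ("ps-host-0.example.com", 12345),
    ("ps-host-1.example.com", 12345),
)

ps_ips = [host[0] for host in ps_hosts]    # ["ps-host-0.example.com", "ps-host-1.example.com"]
ps_ports = [host[1] for host in ps_hosts]  # [12345, 12345]

# Passing ps_hosts=ps_hosts (and optionally tbe_unique_id) to
# SSDTableBatchedEmbeddingBags or SSDIntNBitTableBatchedEmbeddingBags selects
# the parameter-server backend; leaving ps_hosts as None keeps the local
# SSD/RocksDB backend.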
emlin authored and facebook-github-bot committed Jun 4, 2024
1 parent eb7b841 commit 4449cbc
Showing 3 changed files with 238 additions and 41 deletions.
126 changes: 85 additions & 41 deletions fbgemm_gpu/fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py
@@ -39,7 +39,7 @@
WeightDecayMode,
)

from torch import nn, Tensor # usort:skip
from torch import distributed as dist, nn, Tensor # usort:skip
from torch.autograd.profiler import record_function

try:
@@ -80,6 +80,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
weights_host: Tensor
weights_placements: Tensor
weights_offsets: Tensor
_local_instance_index: int = -1

def __init__(
self,
@@ -123,6 +124,8 @@ def __init__(
CowClipDefinition
] = None, # used by Rowwise Adagrad
pooling_mode: PoolingMode = PoolingMode.SUM,
ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
tbe_unique_id: int = -1,
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -240,26 +243,45 @@ def __init__(
ssd_directory = tempfile.mkdtemp(
prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
ssd_directory,
ssd_shards,
ssd_shards,
ssd_memtable_flush_period,
ssd_memtable_flush_offset,
ssd_l0_files_per_compact,
self.max_D,
ssd_rate_limit_mbps,
ssd_size_ratio,
ssd_compaction_trigger,
ssd_write_buffer_size,
ssd_max_write_buffer_num,
ssd_uniform_init_lower,
ssd_uniform_init_upper,
weights_precision.bit_rate(), # row_storage_bitwidth
ssd_block_cache_size,
)
# logging.info("DEBUG: weights_precision {}".format(weights_precision))
if not ps_hosts:
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
ssd_directory,
ssd_shards,
ssd_shards,
ssd_memtable_flush_period,
ssd_memtable_flush_offset,
ssd_l0_files_per_compact,
self.max_D,
ssd_rate_limit_mbps,
ssd_size_ratio,
ssd_compaction_trigger,
ssd_write_buffer_size,
ssd_max_write_buffer_num,
ssd_uniform_init_lower,
ssd_uniform_init_upper,
weights_precision.bit_rate(), # row_storage_bitwidth
ssd_block_cache_size,
)
else:
# create tbe unique id using rank index | local tbe instance index
if tbe_unique_id == -1:
SSDTableBatchedEmbeddingBags._local_instance_index += 1
assert (
SSDTableBatchedEmbeddingBags._local_instance_index < 8
), "More than 8 TBE instances were created in one rank; the TBE unique id won't be unique in this case."
tbe_unique_id = dist.get_rank() << 3 | SSDTableBatchedEmbeddingBags._local_instance_index
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
[host[0] for host in ps_hosts],
[host[1] for host in ps_hosts],
tbe_unique_id,
54,  # maxLocalIndexLength
32,  # num_threads
)
# pyre-fixme[20]: Argument `self` expected.
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
self.ssd_stream = torch.cuda.Stream(priority=low_priority)
@@ -705,6 +727,7 @@ class SSDIntNBitTableBatchedEmbeddingBags(nn.Module):
"""

embedding_specs: List[Tuple[str, int, int, SparseType]]
_local_instance_index: int = -1

def __init__(
self,
@@ -733,6 +756,8 @@ def __init__(
ssd_cache_location: EmbeddingLocation = EmbeddingLocation.MANAGED,
ssd_uniform_init_lower: float = -0.01,
ssd_uniform_init_upper: float = 0.01,
ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
tbe_unique_id: int = -1,  # unique id for this embedding; if not set, it is derived from the current rank and the local TBE instance index
) -> None: # noqa C901 # tuple of (rows, dims,)
super(SSDIntNBitTableBatchedEmbeddingBags, self).__init__()

@@ -906,26 +931,45 @@ def max_ty_D(ty: SparseType) -> int:
ssd_directory = tempfile.mkdtemp(
prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
ssd_directory,
ssd_shards,
ssd_shards,
ssd_memtable_flush_period,
ssd_memtable_flush_offset,
ssd_l0_files_per_compact,
self.max_D_cache,
ssd_rate_limit_mbps,
ssd_size_ratio,
ssd_compaction_trigger,
ssd_write_buffer_size,
ssd_max_write_buffer_num,
ssd_uniform_init_lower,
ssd_uniform_init_upper,
8, # row_storage_bitwidth
0, # ssd_block_cache_size
)
if not ps_hosts:
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
ssd_directory,
ssd_shards,
ssd_shards,
ssd_memtable_flush_period,
ssd_memtable_flush_offset,
ssd_l0_files_per_compact,
self.max_D_cache,
ssd_rate_limit_mbps,
ssd_size_ratio,
ssd_compaction_trigger,
ssd_write_buffer_size,
ssd_max_write_buffer_num,
ssd_uniform_init_lower,
ssd_uniform_init_upper,
8, # row_storage_bitwidth
0, # ssd_block_cache_size
)
else:
# create tbe unique id using rank index | local tbe instance index
if tbe_unique_id == -1:
SSDIntNBitTableBatchedEmbeddingBags._local_instance_index += 1
assert (
SSDIntNBitTableBatchedEmbeddingBags._local_instance_index < 8
), "More than 8 TBE instances were created in one rank; the TBE unique id won't be unique in this case."
tbe_unique_id = dist.get_rank() << 3 | SSDIntNBitTableBatchedEmbeddingBags._local_instance_index
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
[host[0] for host in ps_hosts],
[host[1] for host in ps_hosts],
tbe_unique_id,
54,  # maxLocalIndexLength
32,  # num_threads
)

# pyre-fixme[20]: Argument `self` expected.
(low_priority, high_priority) = torch.cuda.Stream.priority_range()
self.ssd_stream = torch.cuda.Stream(priority=low_priority)
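Both classes derive tbe_unique_id by packing the trainer rank and a per-rank TBE instance counter into a single integer. A standalone sketch of that arithmetic follows (the helper name is hypothetical; the in-tree code uses dist.get_rank() and the class-level counter shown above):

# Sketch of the id derivation: rank << 3 | local_instance_index.
# The three low bits carry the per-rank instance index (hence the "< 8"
# assertion); the remaining bits carry the trainer rank.
def derive_tbe_unique_id(rank: int, local_instance_index: int) -> int:
    assert 0 <= local_instance_index < 8, "at most 8 TBE instances per rank"
    return rank << 3 | local_instance_index

assert derive_tbe_unique_id(rank=0, local_instance_index=0) == 0
assert derive_tbe_unique_id(rank=0, local_instance_index=7) == 7
assert derive_tbe_unique_id(rank=2, local_instance_index=1) == 17  # 0b10001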
@@ -0,0 +1,89 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "./ps_table_batched_embeddings.h"

#include <torch/custom_class.h>
#include "fbgemm_gpu/sparse_ops_utils.h"

using namespace at;
using namespace ps;

namespace {
class EmbeddingParameterServerWrapper : public torch::jit::CustomClassHolder {
public:
EmbeddingParameterServerWrapper(
const std::vector<std::string>& tps_ips,
const std::vector<int64_t>& tps_ports,
int64_t tbe_id,
int64_t maxLocalIndexLength = 54,
int64_t num_threads = 32) {
TORCH_CHECK(
tps_ips.size() == tps_ports.size(),
"tps_ips and tps_ports must have the same size");
std::vector<std::pair<std::string, int>> tpsHosts = {};
for (int i = 0; i < tps_ips.size(); i++) {
tpsHosts.push_back(std::make_pair(tps_ips[i], tps_ports[i]));
}

impl_ = std::make_shared<ps::EmbeddingParameterServer>(
std::move(tpsHosts), tbe_id, maxLocalIndexLength, num_threads);
}

void
set_cuda(Tensor indices, Tensor weights, Tensor count, int64_t timestep) {
return impl_->set_cuda(indices, weights, count, timestep);
}

void get_cuda(Tensor indices, Tensor weights, Tensor count) {
return impl_->get_cuda(indices, weights, count);
}

void set(Tensor indices, Tensor weights, Tensor count) {
return impl_->set(indices, weights, count);
}

void get(Tensor indices, Tensor weights, Tensor count) {
return impl_->get(indices, weights, count);
}

void compact() {
return impl_->compact();
}

void flush() {
return impl_->flush();
}

void cleanup() {
return impl_->cleanup();
}

private:
// shared pointer since we use shared_from_this() in callbacks.
std::shared_ptr<EmbeddingParameterServer> impl_;
};

static auto embedding_parameter_server_wrapper =
torch::class_<EmbeddingParameterServerWrapper>(
"fbgemm",
"EmbeddingParameterServerWrapper")
.def(torch::init<
const std::vector<std::string>,
const std::vector<int64_t>,
int64_t,
int64_t,
int64_t>())
.def("set_cuda", &EmbeddingParameterServerWrapper::set_cuda)
.def("get_cuda", &EmbeddingParameterServerWrapper::get_cuda)
.def("compact", &EmbeddingParameterServerWrapper::compact)
.def("flush", &EmbeddingParameterServerWrapper::flush)
.def("set", &EmbeddingParameterServerWrapper::set)
.def("get", &EmbeddingParameterServerWrapper::get)
.def("cleanup", &EmbeddingParameterServerWrapper::cleanup);
} // namespace
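
For reference, a hedged sketch of how the registered class can be reached from Python once the fbgemm_gpu extension built with PS support has been loaded; the host, port, and tensor shapes below are illustrative assumptions, mirroring the call made in ssd_split_table_batched_embeddings_ops.py:

import torch
import fbgemm_gpu  # noqa: F401  # assumed to load the extension that registers the class

ps = torch.classes.fbgemm.EmbeddingParameterServerWrapper(
    ["ps-host-0.example.com"],  # tps_ips (hypothetical)
    [12345],                    # tps_ports (hypothetical)
    0,                          # tbe_id
    54,                         # maxLocalIndexLength
    32,                         # num_threads
)

indices = torch.tensor([1, 3, 5], dtype=torch.int64)
weights = torch.randn(3, 128)  # row width of 128 is an assumption
count = torch.tensor([3])      # number of valid rows; read via count.item() on the C++ side

ps.set(indices, weights, count)  # push rows to the parameter server
ps.get(indices, weights, count)  # expected to fill `weights` in place
ps.flush()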
@@ -0,0 +1,64 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#include "../ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h"

#include <folly/experimental/coro/BlockingWait.h>
#include "mvai_infra/experimental/ps_training/tps_client/TrainingParameterServiceClient.h"

namespace ps {

class EmbeddingParameterServer : public kv_db::EmbeddingKVDB {
public:
explicit EmbeddingParameterServer(
std::vector<std::pair<std::string, int>>&& tps_hosts,
int64_t tbe_id,
int64_t maxLocalIndexLength = 54,
int64_t num_threads = 32)
: tps_client_(
std::make_shared<mvai_infra::experimental::ps_training::tps_client::
TrainingParameterServiceClient>(
std::move(tps_hosts),
tbe_id,
maxLocalIndexLength,
num_threads)) {}

void set(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count) override {
RECORD_USER_SCOPE("EmbeddingParameterServer::set");
folly::coro::blockingWait(
tps_client_->set(indices, weights, count.item().toLong()));
}
void get(
const at::Tensor& indices,
const at::Tensor& weights,
const at::Tensor& count) override {
RECORD_USER_SCOPE("EmbeddingParameterServer::get");
folly::coro::blockingWait(
tps_client_->get(indices, weights, count.item().toLong()));
}
void flush() override {}
void compact() override {}
// cleanup cached results on the server side
// This is a test helper, please do not use it in production
void cleanup() {
folly::coro::blockingWait(tps_client_->cleanup());
}

private:
void flush_or_compact(const int64_t /*timestep*/) override {}

std::shared_ptr<mvai_infra::experimental::ps_training::tps_client::
TrainingParameterServiceClient>
tps_client_;
}; // class EmbeddingParameterServer

} // namespace ps
