Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【GPUPS】add env for gpups and fix cache table #70301

Merged
merged 2 commits into from
Dec 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
@@ -439,7 +439,9 @@ int FlClientBrpcClosure::check_response(size_t request_idx, int cmd_id) {
return 0;
}

std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, table_id](void *done) {
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/service/brpc_ps_client.h
Original file line number Diff line number Diff line change
@@ -273,7 +273,9 @@ class BrpcPsClient : public PSClient {
size_t num,
bool is_training);

virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual std::future<int32_t> PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold);

virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);

4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/service/ps_client.h
Original file line number Diff line number Diff line change
@@ -181,7 +181,9 @@ class PSClient {
return fut;
}

virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
virtual std::future<int32_t> PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold) = 0;
virtual std::future<int32_t> SaveCacheTable(uint32_t table_id UNUSED,
uint16_t pass_id UNUSED,
size_t threshold UNUSED) {
9 changes: 8 additions & 1 deletion paddle/fluid/distributed/ps/service/ps_local_client.cc
Original file line number Diff line number Diff line change
@@ -256,11 +256,18 @@ ::std::future<int32_t> PsLocalClient::PullSparsePtr(
return done();
}

::std::future<int32_t> PsLocalClient::PrintTableStat(uint32_t table_id) {
::std::future<int32_t> PsLocalClient::PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold) {
auto* table_ptr = GetTable(table_id);
std::pair<int64_t, int64_t> ret = table_ptr->PrintTableStat();
VLOG(0) << "table id: " << table_id << ", feasign size: " << ret.first
<< ", mf size: " << ret.second;
  // > 5 billion feasigns, roughly 40% memory usage
if (static_cast<size_t>(ret.first) > threshold) {
VLOG(0) << "run cache table";
table_ptr->CacheTable(pass_id);
}
return done();
}

4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/service/ps_local_client.h
Original file line number Diff line number Diff line change
@@ -87,7 +87,9 @@ class PsLocalClient : public PSClient {
const std::vector<std::unordered_map<uint64_t, uint32_t>>& keys2rank_vec,
const uint16_t& dim_id = 0);

virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold);

virtual ::std::future<int32_t> SaveCacheTable(uint32_t table_id,
uint16_t pass_id,
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc
Original file line number Diff line number Diff line change
@@ -232,7 +232,7 @@ void CtrDymfAccessor::UpdateStatAfterSave(float* value, int param) {
int32_t CtrDymfAccessor::Create(float** values, size_t num) {
for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item];
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
common_feature_value.UnseenDays(value) = 0;
common_feature_value.PassId(value) = 0;
#else
@@ -385,7 +385,7 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) {

int CtrDymfAccessor::ParseFromString(const std::string& str, float* value) {
auto ret = paddle::string::str_to_float(str.data(), value);
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
float unseen_day = value[common_feature_value.UnseenDaysIndex()];
common_feature_value.UnseenDays(value) = (uint16_t)(unseen_day);
common_feature_value.PassId(value) = 0;
@@ -437,7 +437,7 @@ void CtrDymfAccessor::UpdateTimeDecay(float* value, bool is_update_seen_day) {
}
}

#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
bool CtrDymfAccessor::SaveMemCache(float* value,
int param,
double global_cache_threshold,
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h
Original file line number Diff line number Diff line change
@@ -88,7 +88,7 @@ class CtrDymfAccessor : public ValueAccessor {
  // Total byte size computed from mf_dim
int Size(int mf_dim) { return (Dim(mf_dim)) * sizeof(float); }

#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
uint16_t& PassId(float* val) {
uint16_t* int16_val =
reinterpret_cast<uint16_t*>(val + UnseenDaysIndex());
@@ -258,7 +258,7 @@ class CtrDymfAccessor : public ValueAccessor {

void SetDayId(int day_id) override;

#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
  // Decide whether to cache to SSD based on pass_id and the show_threshold threshold
bool SaveMemCache(float* value,
int param,
12 changes: 6 additions & 6 deletions paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
Original file line number Diff line number Diff line change
@@ -265,7 +265,7 @@ int32_t SSDSparseTable::PullSparsePtr(int shard_id,
}

_value_accessor->UpdateTimeDecay(ret->data(), true);
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(ret->data(), pass_id);
#endif
int pull_data_idx = cur_ctx->batch_index[idx];
@@ -280,7 +280,7 @@ int32_t SSDSparseTable::PullSparsePtr(int shard_id,
ret = itr.value_ptr();
// int pull_data_idx = keys[i].second;
_value_accessor->UpdateTimeDecay(ret->data(), true);
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(ret->data(), pass_id);
#endif
pull_values[i] = reinterpret_cast<char*>(ret);
@@ -332,7 +332,7 @@ int32_t SSDSparseTable::PullSparsePtr(int shard_id,
ret = &feature_value;
}
_value_accessor->UpdateTimeDecay(ret->data(), true);
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(ret->data(), pass_id);
#endif
int pull_data_idx = cur_ctx->batch_index[idx];
@@ -2945,7 +2945,7 @@ int32_t SSDSparseTable::LoadWithBinary(const std::string& path, int param) {
abort();
}
last_k = k;
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(convert_value, 0);
#endif
rocksdb::Status status = sst_writer.Put(
@@ -2963,7 +2963,7 @@ int32_t SSDSparseTable::LoadWithBinary(const std::string& path, int param) {
}
} else {
auto& feature_value = shard[k];
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(convert_value, 0);
#endif
feature_value.resize(dim);
@@ -3051,7 +3051,7 @@ std::pair<int64_t, int64_t> SSDSparseTable::PrintTableStat() {

int32_t SSDSparseTable::CacheTable(uint16_t pass_id) {
std::lock_guard<std::mutex> guard(_table_mutex);
VLOG(0) << "cache_table";
VLOG(0) << "cache_table, pass_id:" << pass_id;
std::atomic<uint32_t> count{0};
std::vector<std::future<int>> tasks;

6 changes: 4 additions & 2 deletions paddle/fluid/distributed/ps/wrapper/fleet.cc
Original file line number Diff line number Diff line change
@@ -815,8 +815,10 @@ void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
}
}

void FleetWrapper::PrintTableStat(const uint64_t table_id) {
auto ret = worker_ptr_->PrintTableStat(table_id);
void FleetWrapper::PrintTableStat(const uint64_t table_id,
uint32_t pass_id,
size_t threshold) {
auto ret = worker_ptr_->PrintTableStat(table_id, pass_id, threshold);
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/wrapper/fleet.h
Original file line number Diff line number Diff line change
@@ -241,7 +241,9 @@ class FleetWrapper {
// barrier with barrier table
void BarrierWithTable(uint32_t barrier_type);

void PrintTableStat(const uint64_t table_id);
void PrintTableStat(const uint64_t table_id,
uint32_t pass_id,
size_t threshold);
void SaveCacheTable(const uint64_t table_id,
uint16_t pass_id,
size_t threshold);
18 changes: 17 additions & 1 deletion paddle/fluid/framework/fleet/ps_gpu_wrapper.h
Original file line number Diff line number Diff line change
@@ -14,8 +14,8 @@ limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_HETERPS

#include <google/protobuf/text_format.h>
#include <stdlib.h>
#include <atomic>
#include <ctime>
#include <map>
@@ -390,6 +390,22 @@ class PSGPUWrapper {
if (s_instance_ != NULL && is_initialized_ == false) {
VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
is_initialized_ = true;
#if defined(PADDLE_WITH_PSCORE) && defined(PADDLE_WITH_HETERPS) && \
defined(PADDLE_WITH_NCCL)
const char* launch_mode = std::getenv("NCCL_LAUNCH_MODE");
if (launch_mode != nullptr) {
if (std::string(launch_mode) == "PARALLEL") {
PADDLE_THROW(common::errors::Unavailable(
"on heterps-mode you must export NCCL_LAUNCH_MODE=GROUP for no "
"hang, but received [%s]",
launch_mode));
}
} else {
PADDLE_THROW(
common::errors::Unavailable("on heterps-mode you must export "
"NCCL_LAUNCH_MODE=GROUP for no hang"));
}
#endif
resource_ = std::make_shared<HeterPsResource>(dev_ids);
resource_->enable_p2p();
keys_tensor.resize(resource_->total_device());
1 change: 1 addition & 0 deletions paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
@@ -81,6 +81,7 @@ void BindDistFleetWrapper(py::module* m) {
.def("pull_fl_strategy", &FleetWrapper::PullFlStrategy)
.def("revert", &FleetWrapper::Revert)
.def("set_date", &FleetWrapper::SetDate)
.def("print_table_stat", &FleetWrapper::PrintTableStat)
.def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone);
}

1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/__init__.py
Original file line number Diff line number Diff line change
@@ -99,6 +99,7 @@
load_inference_model = fleet.load_inference_model
load_one_table = fleet.load_one_table
set_date = fleet.set_date
print_table_stat = fleet.print_table_stat
minimize = fleet.minimize
distributed_model = distributed_model
shrink = fleet.shrink
21 changes: 21 additions & 0 deletions python/paddle/distributed/fleet/fleet.py
Original file line number Diff line number Diff line change
@@ -1419,6 +1419,27 @@ def set_date(self, table_id: int, day_id: str) -> None:
"""
self._runtime_handle._set_date(table_id, str(day_id))

@is_non_distributed_check
@inited_runtime_handler
def print_table_stat(self, table_id: int, pass_id: int, threshold: float):
"""
Print stat info of table_id for gpups table, format: tableid, feasign size, mf size.

Args:

table_id (int): The id of table.
pass_id (int): The id of pass.
threshold (float): The threshold of print.

Examples:

.. code-block:: text

fleet.print_table_stat(0,6,600000)

"""
self._runtime_handle._print_table_stat(table_id, pass_id, threshold)

@is_non_distributed_check
@inited_runtime_handler
def shrink(self, threshold: int | None = None) -> None:
6 changes: 6 additions & 0 deletions python/paddle/distributed/ps/the_one_ps.py
Original file line number Diff line number Diff line change
@@ -1760,6 +1760,12 @@ def _set_date(self, table_id, day_id):
self._worker.set_date(table_id, day_id)
fleet.util.barrier()

def _print_table_stat(self, table_id, pass_id, threshold):
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._worker.print_table_stat(table_id, pass_id, threshold)
fleet.util.barrier()

def _shrink(self, threshold=None):
if threshold is not None:
warnings.warn(