Skip to content

Commit

Permalink
[Feature] enable to specify stream in UnitGraph::CopyTo() which could…
Browse files Browse the repository at this point in the history
… lead to async copy (dmlc#3297)

Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
Rhett-Ying and jermainewang authored Sep 1, 2021
1 parent f4fe518 commit 5a24510
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 40 deletions.
9 changes: 5 additions & 4 deletions include/dgl/aten/coo.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,13 @@ struct COOMatrix {
}

/*! \brief Return a copy of this matrix on the given device context. */
inline COOMatrix CopyTo(const DLContext& ctx) const {
inline COOMatrix CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (ctx == row->ctx)
return *this;
return COOMatrix(num_rows, num_cols,
row.CopyTo(ctx), col.CopyTo(ctx),
aten::IsNullArray(data)? data : data.CopyTo(ctx),
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx, stream),
col.CopyTo(ctx, stream),
aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
row_sorted, col_sorted);
}
};
Expand Down
9 changes: 5 additions & 4 deletions include/dgl/aten/csr.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,13 @@ struct CSRMatrix {
}

/*! \brief Return a copy of this matrix on the given device context. */
inline CSRMatrix CopyTo(const DLContext& ctx) const {
inline CSRMatrix CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (ctx == indptr->ctx)
return *this;
return CSRMatrix(num_rows, num_cols,
indptr.CopyTo(ctx), indices.CopyTo(ctx),
aten::IsNullArray(data)? data : data.CopyTo(ctx),
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx, stream),
indices.CopyTo(ctx, stream),
aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
sorted);
}
};
Expand Down
30 changes: 18 additions & 12 deletions include/dgl/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,21 @@ class NDArray {
* \note The copy may happen asynchronously if it involves a GPU context.
* DGLSynchronize is necessary.
*/
inline void CopyTo(DLTensor* other) const;
inline void CopyTo(const NDArray& other) const;
inline void CopyTo(DLTensor *other,
const DGLStreamHandle &stream = nullptr) const;
inline void CopyTo(const NDArray &other,
const DGLStreamHandle &stream = nullptr) const;
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \param stream Stream handle passed through to the underlying copy (defaults to nullptr).
* \return The array under another context.
*/
inline NDArray CopyTo(const DLContext& ctx) const;
inline NDArray CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const;
/*!
* \brief Return a new array with a copy of the content.
*/
inline NDArray Clone() const;
inline NDArray Clone(const DGLStreamHandle &stream = nullptr) const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
Expand Down Expand Up @@ -401,30 +404,33 @@ inline void NDArray::CopyFrom(const NDArray& other,
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor), stream);
}

inline void NDArray::CopyTo(DLTensor* other) const {
inline void NDArray::CopyTo(DLTensor *other,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
CopyFromTo(&(data_->dl_tensor), other, stream);
}

inline void NDArray::CopyTo(const NDArray& other) const {
inline void NDArray::CopyTo(const NDArray &other,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor));
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor), stream);
}

inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
inline NDArray NDArray::CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx);
this->CopyTo(ret);
this->CopyTo(ret, stream);
return ret;
}

inline NDArray NDArray::Clone() const {
inline NDArray NDArray::Clone(const DGLStreamHandle &stream) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
return this->CopyTo(dptr->ctx);
return this->CopyTo(dptr->ctx, stream);
}

inline int NDArray::use_count() const {
Expand Down
5 changes: 3 additions & 2 deletions src/graph/heterograph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,16 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
hgindex->num_verts_per_type_));
}

HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream) {
if (ctx == g->Context()) {
return g;
}
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(g);
CHECK_NOTNULL(hgindex);
std::vector<HeteroGraphPtr> rel_graphs;
for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::CopyTo(g, ctx));
rel_graphs.push_back(UnitGraph::CopyTo(g, ctx, stream));
}
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
hgindex->num_verts_per_type_));
Expand Down
3 changes: 2 additions & 1 deletion src/graph/heterograph.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ class HeteroGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);

/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream = nullptr);

/*! \brief Copy the data to shared memory.
*
Expand Down
28 changes: 17 additions & 11 deletions src/graph/unit_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return ret;
}

COO CopyTo(const DLContext& ctx) const {
COO CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (Context() == ctx)
return *this;
return COO(meta_graph_, adj_.CopyTo(ctx));
return COO(meta_graph_, adj_.CopyTo(ctx, stream));
}

bool IsMultigraph() const override {
Expand Down Expand Up @@ -537,11 +538,12 @@ class UnitGraph::CSR : public BaseHeteroGraph {
}
}

CSR CopyTo(const DLContext& ctx) const {
CSR CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (Context() == ctx) {
return *this;
} else {
return CSR(meta_graph_, adj_.CopyTo(ctx));
return CSR(meta_graph_, adj_.CopyTo(ctx, stream));
}
}

Expand Down Expand Up @@ -1232,18 +1234,22 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
}
}

HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream) {
if (ctx == g->Context()) {
return g;
} else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr =
(bg->in_csr_->defined())? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr;
CSRPtr new_outcsr =
(bg->out_csr_->defined())? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr;
COOPtr new_coo =
(bg->coo_->defined())? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr;
CSRPtr new_incsr = (bg->in_csr_->defined())
? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx, stream)))
: nullptr;
CSRPtr new_outcsr = (bg->out_csr_->defined())
? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx, stream)))
: nullptr;
COOPtr new_coo = (bg->coo_->defined())
? COOPtr(new COO(bg->coo_->CopyTo(ctx, stream)))
: nullptr;
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}
Expand Down
3 changes: 2 additions & 1 deletion src/graph/unit_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ class UnitGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);

/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream = nullptr);

/*!
* \brief Create in-edge CSR format of the unit graph.
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/cpu_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class CPUDeviceAPI final : public DeviceAPI {
size);
}

/*! \brief Stream creation for this device: no stream object is made, nullptr is returned. */
DGLStreamHandle CreateStream(DGLContext /*ctx*/) final {
  return nullptr;
}

/*! \brief Stream synchronization for this device; the body is empty, so the call does nothing. */
void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {}

Expand Down
52 changes: 47 additions & 5 deletions tests/cpp/test_unit_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
* \file test_unit_graph.cc
* \brief Test UnitGraph
*/
#include <gtest/gtest.h>
#include "../../src/graph/unit_graph.h"
#include "./../src/graph/heterograph.h"
#include "./common.h"
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/device_api.h>
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include <dgl/immutable_graph.h>
#include "./common.h"
#include "./../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"

using namespace dgl;
using namespace dgl::runtime;
Expand Down Expand Up @@ -298,6 +299,47 @@ void _TestUnitGraph_Reserve(DLContext ctx) {
ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);
}

template <typename IdType>
void _TestUnitGraph_CopyTo(const DLContext &src_ctx,
const DGLContext &dst_ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(src_ctx);
const aten::COOMatrix &coo = COO1<IdType>(src_ctx);

auto device = dgl::runtime::DeviceAPI::Get(dst_ctx);
auto stream = device->CreateStream(dst_ctx);

auto g = dgl::UnitGraph::CreateFromCSC(2, csr);
ASSERT_EQ(g->GetCreatedFormats(), 4);
auto cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 4);

g = dgl::UnitGraph::CreateFromCSR(2, csr);
ASSERT_EQ(g->GetCreatedFormats(), 2);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 2);

g = dgl::UnitGraph::CreateFromCOO(2, coo);
ASSERT_EQ(g->GetCreatedFormats(), 1);
cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream);
device->StreamSync(dst_ctx, stream);
ASSERT_EQ(cg->GetCreatedFormats(), 1);
}

// Runs _TestUnitGraph_CopyTo for both int32_t and int64_t id types over every
// (source, destination) context pair listed below.
TEST(UniGraphTest, TestUnitGraph_CopyTo) {
// Source and destination are both CPU; these two calls are compiled in every build.
_TestUnitGraph_CopyTo<int32_t>(CPU, CPU);
_TestUnitGraph_CopyTo<int64_t>(CPU, CPU);
#ifdef DGL_USE_CUDA
// Pairs that involve GPU are compiled only when DGL_USE_CUDA is defined.
// NOTE(review): CPU / GPU are presumably context constants from the test's
// common header -- confirm.
_TestUnitGraph_CopyTo<int32_t>(CPU, GPU);
_TestUnitGraph_CopyTo<int32_t>(GPU, GPU);
_TestUnitGraph_CopyTo<int32_t>(GPU, CPU);
_TestUnitGraph_CopyTo<int64_t>(CPU, GPU);
_TestUnitGraph_CopyTo<int64_t>(GPU, GPU);
_TestUnitGraph_CopyTo<int64_t>(GPU, CPU);
#endif
}

TEST(UniGraphTest, TestUnitGraph_Create) {
_TestUnitGraph<int32_t>(CPU);
_TestUnitGraph<int64_t>(CPU);
Expand Down

0 comments on commit 5a24510

Please sign in to comment.