diff --git a/include/dgl/aten/coo.h b/include/dgl/aten/coo.h index 43d3bfaa90d5..79859955b8db 100644 --- a/include/dgl/aten/coo.h +++ b/include/dgl/aten/coo.h @@ -120,12 +120,13 @@ struct COOMatrix { } /*! \brief Return a copy of this matrix on the give device context. */ - inline COOMatrix CopyTo(const DLContext& ctx) const { + inline COOMatrix CopyTo(const DLContext &ctx, + const DGLStreamHandle &stream = nullptr) const { if (ctx == row->ctx) return *this; - return COOMatrix(num_rows, num_cols, - row.CopyTo(ctx), col.CopyTo(ctx), - aten::IsNullArray(data)? data : data.CopyTo(ctx), + return COOMatrix(num_rows, num_cols, row.CopyTo(ctx, stream), + col.CopyTo(ctx, stream), + aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream), row_sorted, col_sorted); } }; diff --git a/include/dgl/aten/csr.h b/include/dgl/aten/csr.h index 7e028e0edc56..6e01a6d7bc78 100644 --- a/include/dgl/aten/csr.h +++ b/include/dgl/aten/csr.h @@ -113,12 +113,13 @@ struct CSRMatrix { } /*! \brief Return a copy of this matrix on the give device context. */ - inline CSRMatrix CopyTo(const DLContext& ctx) const { + inline CSRMatrix CopyTo(const DLContext &ctx, + const DGLStreamHandle &stream = nullptr) const { if (ctx == indptr->ctx) return *this; - return CSRMatrix(num_rows, num_cols, - indptr.CopyTo(ctx), indices.CopyTo(ctx), - aten::IsNullArray(data)? data : data.CopyTo(ctx), + return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx, stream), + indices.CopyTo(ctx, stream), + aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream), sorted); } }; diff --git a/include/dgl/runtime/ndarray.h b/include/dgl/runtime/ndarray.h index cbffcdfc2b8f..e64b778f1dee 100644 --- a/include/dgl/runtime/ndarray.h +++ b/include/dgl/runtime/ndarray.h @@ -154,18 +154,21 @@ class NDArray { * \note The copy may happen asynchrously if it involves a GPU context. * DGLSynchronize is necessary. */ - inline void CopyTo(DLTensor* other) const; - inline void CopyTo(const NDArray& other) const; + inline void CopyTo(DLTensor *other, + const DGLStreamHandle &stream = nullptr) const; + inline void CopyTo(const NDArray &other, + const DGLStreamHandle &stream = nullptr) const; /*! * \brief Copy the data to another context. * \param ctx The target context. * \return The array under another context. */ - inline NDArray CopyTo(const DLContext& ctx) const; + inline NDArray CopyTo(const DLContext &ctx, + const DGLStreamHandle &stream = nullptr) const; /*! * \brief Return a new array with a copy of the content. */ - inline NDArray Clone() const; + inline NDArray Clone(const DGLStreamHandle &stream = nullptr) const; /*! * \brief Load NDArray from stream * \param stream The input data stream @@ -401,30 +404,33 @@ inline void NDArray::CopyFrom(const NDArray& other, CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor), stream); } -inline void NDArray::CopyTo(DLTensor* other) const { +inline void NDArray::CopyTo(DLTensor *other, + const DGLStreamHandle &stream) const { CHECK(data_ != nullptr); - CopyFromTo(&(data_->dl_tensor), other); + CopyFromTo(&(data_->dl_tensor), other, stream); } -inline void NDArray::CopyTo(const NDArray& other) const { +inline void NDArray::CopyTo(const NDArray &other, + const DGLStreamHandle &stream) const { CHECK(data_ != nullptr); CHECK(other.data_ != nullptr); - CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor)); + CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor), stream); } -inline NDArray NDArray::CopyTo(const DLContext& ctx) const { +inline NDArray NDArray::CopyTo(const DLContext &ctx, + const DGLStreamHandle &stream) const { CHECK(data_ != nullptr); const DLTensor* dptr = operator->(); NDArray ret = Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, ctx); - this->CopyTo(ret); + this->CopyTo(ret, stream); return ret; } -inline NDArray NDArray::Clone() const { +inline NDArray NDArray::Clone(const DGLStreamHandle &stream) const { CHECK(data_ != nullptr); const DLTensor* dptr = operator->(); - return this->CopyTo(dptr->ctx); + return this->CopyTo(dptr->ctx, stream); } inline int NDArray::use_count() const { diff --git a/src/graph/heterograph.cc b/src/graph/heterograph.cc index c518ea1ce281..534e22f64a74 100644 --- a/src/graph/heterograph.cc +++ b/src/graph/heterograph.cc @@ -254,7 +254,8 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { hgindex->num_verts_per_type_)); } -HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { +HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx, + const DGLStreamHandle &stream) { if (ctx == g->Context()) { return g; } @@ -262,7 +263,7 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { CHECK_NOTNULL(hgindex); std::vector rel_graphs; for (auto g : hgindex->relation_graphs_) { - rel_graphs.push_back(UnitGraph::CopyTo(g, ctx)); + rel_graphs.push_back(UnitGraph::CopyTo(g, ctx, stream)); } return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_)); diff --git a/src/graph/heterograph.h b/src/graph/heterograph.h index 02657c807282..832f91e0600f 100644 --- a/src/graph/heterograph.h +++ b/src/graph/heterograph.h @@ -225,7 +225,8 @@ class HeteroGraph : public BaseHeteroGraph { static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); /*! \brief Copy the data to another context */ - static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx); + static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx, + const DGLStreamHandle &stream = nullptr); /*! \brief Copy the data to shared memory. * diff --git a/src/graph/unit_graph.cc b/src/graph/unit_graph.cc index f59ce08f45e0..727b62c402e3 100644 --- a/src/graph/unit_graph.cc +++ b/src/graph/unit_graph.cc @@ -149,10 +149,11 @@ class UnitGraph::COO : public BaseHeteroGraph { return ret; } - COO CopyTo(const DLContext& ctx) const { + COO CopyTo(const DLContext &ctx, + const DGLStreamHandle &stream = nullptr) const { if (Context() == ctx) return *this; - return COO(meta_graph_, adj_.CopyTo(ctx)); + return COO(meta_graph_, adj_.CopyTo(ctx, stream)); } bool IsMultigraph() const override { @@ -537,11 +538,12 @@ class UnitGraph::CSR : public BaseHeteroGraph { } } - CSR CopyTo(const DLContext& ctx) const { + CSR CopyTo(const DLContext &ctx, + const DGLStreamHandle &stream = nullptr) const { if (Context() == ctx) { return *this; } else { - return CSR(meta_graph_, adj_.CopyTo(ctx)); + return CSR(meta_graph_, adj_.CopyTo(ctx, stream)); } } @@ -1232,18 +1234,22 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { } } -HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { +HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext &ctx, + const DGLStreamHandle &stream) { if (ctx == g->Context()) { return g; } else { auto bg = std::dynamic_pointer_cast(g); CHECK_NOTNULL(bg); - CSRPtr new_incsr = - (bg->in_csr_->defined())? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr; - CSRPtr new_outcsr = - (bg->out_csr_->defined())? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr; - COOPtr new_coo = - (bg->coo_->defined())? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr; + CSRPtr new_incsr = (bg->in_csr_->defined()) + ? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx, stream))) + : nullptr; + CSRPtr new_outcsr = (bg->out_csr_->defined()) + ? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx, stream))) + : nullptr; + COOPtr new_coo = (bg->coo_->defined()) + ? COOPtr(new COO(bg->coo_->CopyTo(ctx, stream))) + : nullptr; return HeteroGraphPtr( new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_)); } diff --git a/src/graph/unit_graph.h b/src/graph/unit_graph.h index 45970ed45896..e76e3f6e2c01 100644 --- a/src/graph/unit_graph.h +++ b/src/graph/unit_graph.h @@ -205,7 +205,8 @@ class UnitGraph : public BaseHeteroGraph { static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); /*! \brief Copy the data to another context */ - static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx); + static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx, + const DGLStreamHandle &stream = nullptr); /*! * \brief Create in-edge CSR format of the unit graph. diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 74b96d060d89..a0d341168f87 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -60,6 +60,8 @@ class CPUDeviceAPI final : public DeviceAPI { size); } + DGLStreamHandle CreateStream(DGLContext) final { return nullptr; } + void StreamSync(DGLContext ctx, DGLStreamHandle stream) final { } diff --git a/tests/cpp/test_unit_graph.cc b/tests/cpp/test_unit_graph.cc index 86b6703f61c8..d9b092046d26 100644 --- a/tests/cpp/test_unit_graph.cc +++ b/tests/cpp/test_unit_graph.cc @@ -3,14 +3,15 @@ * \file test_unit_graph.cc * \brief Test UnitGraph */ -#include +#include "../../src/graph/unit_graph.h" +#include "./../src/graph/heterograph.h" +#include "./common.h" #include +#include +#include +#include #include #include -#include -#include "./common.h" -#include "./../src/graph/heterograph.h" -#include "../../src/graph/unit_graph.h" using namespace dgl; using namespace dgl::runtime; @@ -298,6 +299,47 @@ void _TestUnitGraph_Reserve(DLContext ctx) { ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data); } +template +void _TestUnitGraph_CopyTo(const DLContext &src_ctx, + const DGLContext &dst_ctx) { + const aten::CSRMatrix &csr = CSR1(src_ctx); + const aten::COOMatrix &coo = COO1(src_ctx); + + auto device = dgl::runtime::DeviceAPI::Get(dst_ctx); + auto stream = device->CreateStream(dst_ctx); + + auto g = dgl::UnitGraph::CreateFromCSC(2, csr); + ASSERT_EQ(g->GetCreatedFormats(), 4); + auto cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream); + device->StreamSync(dst_ctx, stream); + ASSERT_EQ(cg->GetCreatedFormats(), 4); + + g = dgl::UnitGraph::CreateFromCSR(2, csr); + ASSERT_EQ(g->GetCreatedFormats(), 2); + cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream); + device->StreamSync(dst_ctx, stream); + ASSERT_EQ(cg->GetCreatedFormats(), 2); + + g = dgl::UnitGraph::CreateFromCOO(2, coo); + ASSERT_EQ(g->GetCreatedFormats(), 1); + cg = dgl::UnitGraph::CopyTo(g, dst_ctx, stream); + device->StreamSync(dst_ctx, stream); + ASSERT_EQ(cg->GetCreatedFormats(), 1); +} + +TEST(UniGraphTest, TestUnitGraph_CopyTo) { + _TestUnitGraph_CopyTo(CPU, CPU); + _TestUnitGraph_CopyTo(CPU, CPU); +#ifdef DGL_USE_CUDA + _TestUnitGraph_CopyTo(CPU, GPU); + _TestUnitGraph_CopyTo(GPU, GPU); + _TestUnitGraph_CopyTo(GPU, CPU); + _TestUnitGraph_CopyTo(CPU, GPU); + _TestUnitGraph_CopyTo(GPU, GPU); + _TestUnitGraph_CopyTo(GPU, CPU); +#endif +} + TEST(UniGraphTest, TestUnitGraph_Create) { _TestUnitGraph(CPU); _TestUnitGraph(CPU);