Skip to content

Commit

Permalink
[Data format] Serialization for Immutable Graph and HeteroGraph (dmlc…
Browse files Browse the repository at this point in the history
…#1254)

* graph format

* fix lint

* lint

* fix

* unit test

* lint

* add magic num

* move serialize out of struct

* lint

* serialize

* trigger CI

* fix lint

* lint

Co-authored-by: zhoujinjing09 <[email protected]>
  • Loading branch information
VoVAllen and zhoujinjing09 authored Feb 10, 2020
1 parent ffe5898 commit 23893bb
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 14 deletions.
10 changes: 10 additions & 0 deletions include/dgl/immutable_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,12 @@ class ImmutableGraph: public GraphInterface {
*/
ImmutableGraphPtr Reverse() const;

/*! \brief Load ImmutableGraph from stream, using CSRMatrix; returns true on success */
bool Load(dmlc::Stream* fs);

/*! \brief Save ImmutableGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;

void SortCSR() {
GetInCSR()->SortCSR();
GetOutCSR()->SortCSR();
Expand Down Expand Up @@ -1028,4 +1034,8 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges,

} // namespace dgl

namespace dmlc {
// Declare that dgl::ImmutableGraph provides Save/Load member functions, so a
// dmlc::Stream can Write/Read it directly (see dmlc/type_traits.h).
DMLC_DECLARE_TRAITS(has_saveload, dgl::ImmutableGraph, true);
} // namespace dmlc

#endif // DGL_IMMUTABLE_GRAPH_H_
25 changes: 25 additions & 0 deletions src/graph/graph_serializer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*!
* Copyright (c) 2018 by Contributors
* \file graph/graph_serializer.h
* \brief DGL serializer APIs
*/

#pragma once

#include <dgl/immutable_graph.h>
#include "heterograph.h"
#include "unit_graph.h"

namespace dgl {

/*!
 * \brief Factory for blank graph objects that are meant to be filled in by a
 *        subsequent Load() call.
 *
 * Every function hands back a raw pointer to a newly allocated object; nothing
 * in this class releases it, so the caller has to take ownership.
 */
class Serializer {
 public:
  /*! \return a newly allocated HeteroGraph built with its empty constructor */
  static HeteroGraph* EmptyHeteroGraph() {
    HeteroGraph* blank = new HeteroGraph();
    return blank;
  }

  /*! \return a newly allocated ImmutableGraph constructed from a null COOPtr */
  static ImmutableGraph* EmptyImmutableGraph() {
    ImmutableGraph* blank = new ImmutableGraph(static_cast<COOPtr>(nullptr));
    return blank;
  }

  /*! \return the pointer produced by UnitGraph::EmptyGraph() */
  static UnitGraph* EmptyUnitGraph() {
    UnitGraph* blank = UnitGraph::EmptyGraph();
    return blank;
  }
};
} // namespace dgl
39 changes: 39 additions & 0 deletions src/graph/heterograph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
* \brief Heterograph implementation
*/
#include "./heterograph.h"
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/immutable_graph.h>
#include <vector>
#include <tuple>
#include <utility>
#include "../c_api_common.h"
#include "./unit_graph.h"
#include "graph_serializer.h"
// TODO(BarclayII): currently CompactGraphs depend on IdHashMap implementation which
// only works on CPU. Should fix later to make it device agnostic.
#include "../array/cpu/array_utils.h"
Expand All @@ -21,6 +25,8 @@ using namespace dgl::runtime;
namespace dgl {
namespace {

using dgl::ImmutableGraph;

HeteroSubgraph EdgeSubgraphPreserveNodes(
const HeteroGraph* hg, const std::vector<IdArray>& eids) {
CHECK_EQ(eids.size(), hg->NumEdgeTypes())
Expand Down Expand Up @@ -494,6 +500,39 @@ CompactGraphs(const std::vector<HeteroGraphPtr> &graphs) {
return result;
}

constexpr uint64_t kDGLSerialize_HeteroGraph = 0xDD589FBE35224ABF;

/*!
 * \brief Replace the content of this graph with the HeteroGraph stored in a stream.
 *
 * Reads, in order: the HeteroGraph magic number, the metagraph (as an
 * ImmutableGraph), the number of relation graphs, and then each relation
 * graph (as a UnitGraph). This is the layout written by HeteroGraph::Save.
 *
 * \param fs the stream to read from
 * \return true on success; malformed input trips a CHECK instead of returning false
 */
bool HeteroGraph::Load(dmlc::Stream* fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_HeteroGraph) << "Invalid HeteroGraph Data";

  // Give every freshly allocated object an owner immediately so that nothing
  // is leaked, whichever way this function is left.
  ImmutableGraph* meta_grptr = new ImmutableGraph(static_cast<COOPtr>(nullptr));
  GraphPtr meta_graph_owner(meta_grptr);
  CHECK(fs->Read(meta_grptr)) << "Invalid Immutable Graph Data";

  uint64_t num_relation_graphs;
  CHECK(fs->Read(&num_relation_graphs)) << "Invalid num of relation graphs";

  std::vector<HeteroGraphPtr> relgraphs;
  for (uint64_t i = 0; i < num_relation_graphs; ++i) {
    UnitGraph* ugptr = Serializer::EmptyUnitGraph();
    HeteroGraphPtr ug_owner(static_cast<BaseHeteroGraph*>(ugptr));
    CHECK(fs->Read(ugptr)) << "Invalid UnitGraph Data";
    relgraphs.push_back(ug_owner);
  }

  // Assign from a temporary instead of a heap object that was never deleted.
  *this = HeteroGraph(meta_graph_owner, relgraphs);
  return true;
}

/*!
 * \brief Serialize this graph to a stream.
 *
 * Writes, in order: the HeteroGraph magic number, the metagraph converted to
 * an ImmutableGraph, the number of relation graphs, and then each relation
 * graph as a UnitGraph.
 *
 * \param fs the stream to write to
 */
void HeteroGraph::Save(dmlc::Stream* fs) const {
  fs->Write(kDGLSerialize_HeteroGraph);
  auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
  fs->Write(*meta_graph_ptr);
  fs->Write(static_cast<uint64_t>(relation_graphs_.size()));
  // Iterate by reference: copying a shared_ptr per element is needless work.
  for (const auto& hptr : relation_graphs_) {
    const UnitGraph* rptr = dynamic_cast<const UnitGraph*>(hptr.get());
    // Only UnitGraph relation graphs can be written; fail loudly rather than
    // dereferencing a null pointer when the cast does not succeed.
    CHECK(rptr != nullptr) << "Relation graph is not a UnitGraph; cannot serialize";
    fs->Write(*rptr);
  }
}

///////////////////////// C APIs /////////////////////////

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
Expand Down
19 changes: 19 additions & 0 deletions src/graph/heterograph.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,20 @@ class HeteroGraph : public BaseHeteroGraph {

FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;

/*! \brief Load HeteroGraph from stream, using CSRMatrix; returns true on success */
bool Load(dmlc::Stream* fs);

/*! \brief Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;


private:
// To create empty class
friend class Serializer;

// Empty constructor, only for the serializer: it is private and reached
// through the befriended Serializer class. It passes a null metagraph pointer
// to BaseHeteroGraph, so the object is a placeholder until it is overwritten
// (e.g. by Load()).
HeteroGraph() : BaseHeteroGraph(static_cast<GraphPtr>(nullptr)) {}

/*! \brief A map from edge type to unit graph */
std::vector<HeteroGraphPtr> relation_graphs_;

Expand All @@ -183,4 +196,10 @@ class HeteroGraph : public BaseHeteroGraph {

} // namespace dgl


namespace dmlc {
// Declare that dgl::HeteroGraph provides Save/Load member functions, so a
// dmlc::Stream can Write/Read it directly (see dmlc/type_traits.h).
DMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true);
} // namespace dmlc


#endif // DGL_GRAPH_HETEROGRAPH_H_
24 changes: 24 additions & 0 deletions src/graph/immutable_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <string.h>
#include <bitset>
#include <numeric>
Expand Down Expand Up @@ -634,6 +636,28 @@ ImmutableGraphPtr ImmutableGraph::Reverse() const {
}
}

constexpr uint64_t kDGLSerialize_ImGraph = 0xDD3c5FFE20046ABF;

/*!
 * \brief Replace the content of this graph with the ImmutableGraph stored in a stream.
 *
 * Reads the ImmutableGraph magic number followed by one CSR matrix, which is
 * installed as the out-CSR of the graph.
 *
 * \param fs the stream to read from
 * \return true on success; malformed input trips a CHECK instead of returning false
 */
bool ImmutableGraph::Load(dmlc::Stream *fs) {
  uint64_t magicNum;
  CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
  CHECK_EQ(magicNum, kDGLSerialize_ImGraph) << "Invalid ImmutableGraph Data";

  aten::CSRMatrix out_csr_matrix;
  CHECK(fs->Read(&out_csr_matrix)) << "Invalid csr matrix";
  CSRPtr csr(new CSR(out_csr_matrix.indptr, out_csr_matrix.indices,
                     out_csr_matrix.data));

  // Assign from a temporary: the previous heap-allocated helper object was
  // copied from and then never deleted.
  *this = ImmutableGraph(nullptr, csr);
  return true;
}

/*!
 * \brief Serialize this ImmutableGraph to a stream.
 *
 * Writes the ImmutableGraph magic number followed by the out-CSR of the graph
 * converted to a CSR matrix.
 *
 * \param fs the stream to write to
 */
void ImmutableGraph::Save(dmlc::Stream *fs) const {
  fs->Write(kDGLSerialize_ImGraph);
  const auto out_csr_matrix = GetOutCSR()->ToCSRMatrix();
  fs->Write(out_csr_matrix);
}

DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
Expand Down
24 changes: 17 additions & 7 deletions src/graph/unit_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
namespace dgl {

namespace {

using namespace dgl::aten;

// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
// a self-loop edge 0->0
Expand Down Expand Up @@ -1152,8 +1155,17 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return SparseFormat::COO;
}

/*!
 * \brief Build a placeholder UnitGraph with no vertices and no edges.
 *
 * The graph uses the metagraph for one vertex type and holds only a COO
 * built from two zero-length id arrays; both CSR slots are left null.
 *
 * \return a newly allocated UnitGraph; the caller takes ownership
 */
UnitGraph* UnitGraph::EmptyGraph() {
  auto no_src = NewIdArray(0);
  auto no_dst = NewIdArray(0);
  auto metagraph = CreateUnitGraphMetaGraph(1);
  COOPtr empty_coo(new COO(metagraph, 0, 0, no_src, no_dst));
  UnitGraph* blank = new UnitGraph(metagraph, nullptr, nullptr, empty_coo);
  return blank;
}

constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;

// Using OutCSR
bool UnitGraph::Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
Expand All @@ -1164,27 +1176,25 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
CHECK(fs->Read(&num_dst)) << "Invalid num_dst";
aten::CSRMatrix csr_matrix;
CHECK(fs->Read(&csr_matrix)) << "Invalid csr_matrix";
SparseFormat restrict_format;
CHECK(fs->Read(&restrict_format)) << "Invalid restrict_format";
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr, csr_matrix.indices, csr_matrix.data));
*this = UnitGraph(mg, csr, nullptr, nullptr);
CSRPtr csr(new CSR(mg, num_src, num_dst, csr_matrix.indptr,
csr_matrix.indices, csr_matrix.data));
*this = UnitGraph(mg, nullptr, csr, nullptr);
return true;
}

/*!
 * \brief Serialize this UnitGraph to a stream, using its out-CSR matrix.
 *
 * Writes, in order: the UnitGraph magic number, the number of vertex types,
 * the number of source vertices, the number of destination vertices, and the
 * out-CSR matrix.
 *
 * \param fs the stream to write to
 */
void UnitGraph::Save(dmlc::Stream* fs) const {
  // Following CreateFromCSR signature
  // A single declaration of the matrix: the out-CSR, as the function comment
  // states. Declaring it twice (in-CSR, then out-CSR) does not compile.
  aten::CSRMatrix csr_matrix = GetOutCSRMatrix();
  uint64_t num_vtypes = NumVertexTypes();
  uint64_t num_src = NumVertices(SrcType());
  uint64_t num_dst = NumVertices(DstType());
  fs->Write(kDGLSerialize_UnitGraphMagic);
  fs->Write(num_vtypes);
  fs->Write(num_src);
  fs->Write(num_dst);
  fs->Write(csr_matrix);
}

} // namespace dgl
5 changes: 5 additions & 0 deletions src/graph/unit_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ class UnitGraph : public BaseHeteroGraph {
void Save(dmlc::Stream* fs) const;

private:
friend class Serializer;

/*!
* \brief constructor
* \param metagraph metagraph
Expand Down Expand Up @@ -219,6 +221,9 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Whether the graph is hypersparse */
bool IsHypersparse() const;

// Empty graph for Serializer usage
static UnitGraph* EmptyGraph();

// Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked.
/*! \brief CSR graph that stores reverse edges */
Expand Down
62 changes: 55 additions & 7 deletions tests/cpp/test_serialize.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <vector>
#include "../../src/graph/graph_serializer.h"
#include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
#include "./common.h"

Expand All @@ -22,13 +25,58 @@ TEST(Serialize, UnitGraph) {
static_cast<dmlc::Stream*>(&ifs)->Write<UnitGraph>(*ug);

dmlc::MemoryStringStream ofs(&blob);
src = NewIdArray(0);
dst = NewIdArray(0);
auto mg2 = dgl::UnitGraph::CreateFromCOO(
1, 0, 0, src, dst); // Any way to construct Empty UnitGraph?
UnitGraph* ug2 = dynamic_cast<UnitGraph*>(mg2.get());
UnitGraph* ug2 = Serializer::EmptyUnitGraph();
static_cast<dmlc::Stream*>(&ofs)->Read(ug2);
EXPECT_EQ(ug2->NumVertices(0), 8);
EXPECT_EQ(ug2->NumVertices(1), 9);
EXPECT_EQ(ug2->NumVertices(0), 9);
EXPECT_EQ(ug2->NumVertices(1), 8);
EXPECT_EQ(ug2->NumEdges(0), 4);
EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);
EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);
}

// Round-trips an ImmutableGraph through an in-memory stream and checks that
// the vertex count, the edge count and one edge survive unchanged.
TEST(Serialize, ImmutableGraph) {
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
  auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);

  std::string blob;
  dmlc::MemoryStringStream ifs(&blob);
  static_cast<dmlc::Stream*>(&ifs)->Write(*gptr);

  dmlc::MemoryStringStream ofs(&blob);
  // Automatic storage: the placeholder graph used to be allocated with `new`
  // and was never released.
  ImmutableGraph loaded(static_cast<COOPtr>(nullptr));
  // Stop right here if deserialization fails; the checks below would only
  // inspect the placeholder.
  ASSERT_TRUE(static_cast<dmlc::Stream*>(&ofs)->Read(&loaded));
  EXPECT_EQ(loaded.NumEdges(), 4);
  EXPECT_EQ(loaded.NumVertices(), 10);
  EXPECT_EQ(loaded.FindEdge(2).first, 5);
  EXPECT_EQ(loaded.FindEdge(2).second, 2);
}

// Round-trips a HeteroGraph with two relation graphs through an in-memory
// stream and checks the vertex counts of the first two vertex types.
TEST(Serialize, HeteroGraph) {
  auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
  auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
  auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
  src = VecToIdArray<int64_t>({6, 2, 5, 1, 9});
  dst = VecToIdArray<int64_t>({5, 2, 4, 9, 0});
  auto mg2 = dgl::UnitGraph::CreateFromCOO(1, 9, 9, src, dst);
  std::vector<HeteroGraphPtr> relgraphs;
  relgraphs.push_back(mg1);
  relgraphs.push_back(mg2);
  src = VecToIdArray<int64_t>({0, 0});
  dst = VecToIdArray<int64_t>({1, 0});
  auto meta_gptr = ImmutableGraph::CreateFromCOO(2, src, dst);
  // Automatic storage: the graph to be saved used to be allocated with `new`
  // and was never released.
  HeteroGraph original(meta_gptr, relgraphs);

  std::string blob;
  dmlc::MemoryStringStream ifs(&blob);
  static_cast<dmlc::Stream*>(&ifs)->Write(original);

  dmlc::MemoryStringStream ofs(&blob);
  HeteroGraph* gptr = dgl::Serializer::EmptyHeteroGraph();
  // The serializer returns a raw owning pointer; give it an owner so it is
  // released when the test ends.
  HeteroGraphPtr gptr_owner(gptr);
  // Stop right here if deserialization fails; the checks below would only
  // inspect the placeholder.
  ASSERT_TRUE(static_cast<dmlc::Stream*>(&ofs)->Read(gptr));

  EXPECT_EQ(gptr->NumVertices(0), 9);
  EXPECT_EQ(gptr->NumVertices(1), 8);
}

0 comments on commit 23893bb

Please sign in to comment.