Skip to content

Commit

Permalink
[Feature] Support serialization for smart pointer (dmlc#1291)
Browse files Browse the repository at this point in the history
* fix script

* t

* fix weird bugs

* fix

* fix

* upload

* fix

* fix

* lint

* fix
  • Loading branch information
VoVAllen authored Feb 26, 2020
1 parent 37d992e commit ebca118
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 6 deletions.
12 changes: 6 additions & 6 deletions include/dgl/runtime/serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,29 @@
namespace dmlc {
namespace serializer {

template<>
template <>
struct Handler<DLDataType> {
inline static void Write(Stream *strm, const DLDataType& dtype) {
inline static void Write(Stream *strm, const DLDataType &dtype) {
Handler<uint8_t>::Write(strm, dtype.code);
Handler<uint8_t>::Write(strm, dtype.bits);
Handler<uint16_t>::Write(strm, dtype.lanes);
}
inline static bool Read(Stream *strm, DLDataType* dtype) {
inline static bool Read(Stream *strm, DLDataType *dtype) {
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
return true;
}
};

template<>
template <>
struct Handler<DLContext> {
inline static void Write(Stream *strm, const DLContext& ctx) {
inline static void Write(Stream *strm, const DLContext &ctx) {
int32_t device_type = static_cast<int32_t>(ctx.device_type);
Handler<int32_t>::Write(strm, device_type);
Handler<int32_t>::Write(strm, ctx.device_id);
}
inline static bool Read(Stream *strm, DLContext* ctx) {
inline static bool Read(Stream *strm, DLContext *ctx) {
int32_t device_type = 0;
if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
ctx->device_type = static_cast<DLDeviceType>(device_type);
Expand Down
51 changes: 51 additions & 0 deletions include/dgl/runtime/smart_ptr_serializer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/serializer.h
* \brief Serializer extension to support DGL data types
* Include this file to enable serialization of DLDataType, DLContext
*/
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_

#include <dmlc/io.h>
#include <dmlc/serializer.h>

namespace dmlc {
namespace serializer {

//! \cond Doxygen_Suppress
template <typename T>
struct Handler<std::shared_ptr<T>> {
inline static void Write(Stream *strm, const std::shared_ptr<T> &data) {
Handler<T>::Write(strm, *data.get());
}
inline static bool Read(Stream *strm, std::shared_ptr<T> *data) {
// When read, the default initialization behavior of shared_ptr is
// shared_ptr<T>(), which is holding a nullptr. Here we need to manually
// reset to a real object for further loading
if (!(*data)) {
data->reset(new T());
}
return Handler<T>::Read(strm, data->get());
}
};

template <typename T>
struct Handler<std::unique_ptr<T>> {
inline static void Write(Stream *strm, const std::unique_ptr<T> &data) {
Handler<T>::Write(strm, *data.get());
}
inline static bool Read(Stream *strm, std::unique_ptr<T> *data) {
// When read, the default initialization behavior of unique_ptr is
// unique_ptr<T>(), which is holding a nullptr. Here we need to manually
// reset to a real object for further loading
if (!(*data)) {
data->reset(new T());
}
return Handler<T>::Read(strm, data->get());
}
};

} // namespace serializer
} // namespace dmlc
#endif // DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
103 changes: 103 additions & 0 deletions tests/cpp/test_smart_ptr_serialize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include <dgl/runtime/serializer.h>
#include <dgl/runtime/smart_ptr_serializer.h>
#include <dmlc/io.h>
#include <dmlc/logging.h>
#include <dmlc/memory_io.h>
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <cstring>
#include <iostream>
#include <sstream>
#include <unordered_map>

using namespace std;

class MyClass {
public:
MyClass() {}
MyClass(std::string data) : data_(data) {}
inline void Save(dmlc::Stream *strm) const { strm->Write(this->data_); }
inline bool Load(dmlc::Stream *strm) { return strm->Read(&data_); }
inline bool operator==(const MyClass &other) const {
return data_ == other.data_;
}

public:
std::string data_;
};
// need to declare the traits property of my class to dmlc
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, MyClass, true);
}

template <typename T>
class SmartPtrTest : public ::testing::Test {
public:
typedef T SmartPtr;
};

using SmartPtrTypes =
::testing::Types<std::shared_ptr<MyClass>, std::unique_ptr<MyClass>>;
TYPED_TEST_SUITE(SmartPtrTest, SmartPtrTypes);

TYPED_TEST(SmartPtrTest, Obj_Test) {
std::string blob;
dmlc::MemoryStringStream fs(&blob);
using SmartPtr = typename TestFixture::SmartPtr;
auto myc = SmartPtr(new MyClass("1111"));
{ static_cast<dmlc::Stream *>(&fs)->Write(myc); }
fs.Seek(0);
auto copy_data = SmartPtr(new MyClass());
CHECK(static_cast<dmlc::Stream *>(&fs)->Read(&copy_data));

EXPECT_EQ(myc->data_, copy_data->data_);
}

TYPED_TEST(SmartPtrTest, Vector_Test1) {
std::string blob;
dmlc::MemoryStringStream fs(&blob);
using SmartPtr = typename TestFixture::SmartPtr;
typedef std::pair<std::string, SmartPtr> Pair;
auto my1 = SmartPtr(new MyClass("@A@"));
auto my2 = SmartPtr(new MyClass("2222"));

std::vector<Pair> myclasses;
myclasses.emplace_back("a", SmartPtr(new MyClass("@A@B")));
myclasses.emplace_back("b", SmartPtr(new MyClass("2222")));
static_cast<dmlc::Stream *>(&fs)->Write<std::vector<Pair>>(myclasses);

dmlc::MemoryStringStream ofs(&blob);
std::vector<Pair> copy_myclasses;
static_cast<dmlc::Stream *>(&ofs)->Read<std::vector<Pair>>(&copy_myclasses);

EXPECT_TRUE(std::equal(myclasses.begin(), myclasses.end(),
copy_myclasses.begin(),
[](const Pair &left, const Pair &right) {
return (left.second->data_ == right.second->data_) &&
(left.first == right.first);
}));
}

TYPED_TEST(SmartPtrTest, Vector_Test2) {
std::string blob;
dmlc::MemoryStringStream fs(&blob);
using SmartPtr = typename TestFixture::SmartPtr;
auto my1 = SmartPtr(new MyClass("@A@"));
auto my2 = SmartPtr(new MyClass("2222"));

std::vector<SmartPtr> myclasses;
myclasses.emplace_back(new MyClass("@A@"));
myclasses.emplace_back(new MyClass("2222"));
static_cast<dmlc::Stream *>(&fs)->Write<std::vector<SmartPtr>>(myclasses);

dmlc::MemoryStringStream ofs(&blob);
std::vector<SmartPtr> copy_myclasses;
static_cast<dmlc::Stream *>(&ofs)->Read<std::vector<SmartPtr>>(
&copy_myclasses);

EXPECT_TRUE(std::equal(myclasses.begin(), myclasses.end(),
copy_myclasses.begin(),
[](const SmartPtr &left, const SmartPtr &right) {
return left->data_ == right->data_;
}));
}

0 comments on commit ebca118

Please sign in to comment.