-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support serialization for smart pointer (dmlc#1291)
* fix script * t * fix weird bugs * fix * fix * upload * fix * fix * lint * fix
- Loading branch information
Showing
3 changed files
with
160 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file dgl/runtime/serializer.h | ||
* \brief Serializer extension to support DGL data types | ||
* Include this file to enable serialization of DLDataType, DLContext | ||
*/ | ||
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_ | ||
#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_ | ||
|
||
#include <dmlc/io.h> | ||
#include <dmlc/serializer.h> | ||
|
||
namespace dmlc { | ||
namespace serializer { | ||
|
||
//! \cond Doxygen_Suppress | ||
template <typename T> | ||
struct Handler<std::shared_ptr<T>> { | ||
inline static void Write(Stream *strm, const std::shared_ptr<T> &data) { | ||
Handler<T>::Write(strm, *data.get()); | ||
} | ||
inline static bool Read(Stream *strm, std::shared_ptr<T> *data) { | ||
// When read, the default initialization behavior of shared_ptr is | ||
// shared_ptr<T>(), which is holding a nullptr. Here we need to manually | ||
// reset to a real object for further loading | ||
if (!(*data)) { | ||
data->reset(new T()); | ||
} | ||
return Handler<T>::Read(strm, data->get()); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct Handler<std::unique_ptr<T>> { | ||
inline static void Write(Stream *strm, const std::unique_ptr<T> &data) { | ||
Handler<T>::Write(strm, *data.get()); | ||
} | ||
inline static bool Read(Stream *strm, std::unique_ptr<T> *data) { | ||
// When read, the default initialization behavior of unique_ptr is | ||
// unique_ptr<T>(), which is holding a nullptr. Here we need to manually | ||
// reset to a real object for further loading | ||
if (!(*data)) { | ||
data->reset(new T()); | ||
} | ||
return Handler<T>::Read(strm, data->get()); | ||
} | ||
}; | ||
|
||
} // namespace serializer | ||
} // namespace dmlc | ||
#endif // DGL_RUNTIME_SMART_PTR_SERIALIZER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
#include <dgl/runtime/serializer.h> | ||
#include <dgl/runtime/smart_ptr_serializer.h> | ||
#include <dmlc/io.h> | ||
#include <dmlc/logging.h> | ||
#include <dmlc/memory_io.h> | ||
#include <dmlc/parameter.h> | ||
#include <gtest/gtest.h> | ||
#include <cstring> | ||
#include <iostream> | ||
#include <sstream> | ||
#include <unordered_map> | ||
|
||
using namespace std; | ||
|
||
class MyClass { | ||
public: | ||
MyClass() {} | ||
MyClass(std::string data) : data_(data) {} | ||
inline void Save(dmlc::Stream *strm) const { strm->Write(this->data_); } | ||
inline bool Load(dmlc::Stream *strm) { return strm->Read(&data_); } | ||
inline bool operator==(const MyClass &other) const { | ||
return data_ == other.data_; | ||
} | ||
|
||
public: | ||
std::string data_; | ||
}; | ||
// need to declare the traits property of my class to dmlc | ||
namespace dmlc { | ||
DMLC_DECLARE_TRAITS(has_saveload, MyClass, true); | ||
} | ||
|
||
template <typename T> | ||
class SmartPtrTest : public ::testing::Test { | ||
public: | ||
typedef T SmartPtr; | ||
}; | ||
|
||
using SmartPtrTypes = | ||
::testing::Types<std::shared_ptr<MyClass>, std::unique_ptr<MyClass>>; | ||
TYPED_TEST_SUITE(SmartPtrTest, SmartPtrTypes); | ||
|
||
TYPED_TEST(SmartPtrTest, Obj_Test) { | ||
std::string blob; | ||
dmlc::MemoryStringStream fs(&blob); | ||
using SmartPtr = typename TestFixture::SmartPtr; | ||
auto myc = SmartPtr(new MyClass("1111")); | ||
{ static_cast<dmlc::Stream *>(&fs)->Write(myc); } | ||
fs.Seek(0); | ||
auto copy_data = SmartPtr(new MyClass()); | ||
CHECK(static_cast<dmlc::Stream *>(&fs)->Read(©_data)); | ||
|
||
EXPECT_EQ(myc->data_, copy_data->data_); | ||
} | ||
|
||
TYPED_TEST(SmartPtrTest, Vector_Test1) { | ||
std::string blob; | ||
dmlc::MemoryStringStream fs(&blob); | ||
using SmartPtr = typename TestFixture::SmartPtr; | ||
typedef std::pair<std::string, SmartPtr> Pair; | ||
auto my1 = SmartPtr(new MyClass("@A@")); | ||
auto my2 = SmartPtr(new MyClass("2222")); | ||
|
||
std::vector<Pair> myclasses; | ||
myclasses.emplace_back("a", SmartPtr(new MyClass("@A@B"))); | ||
myclasses.emplace_back("b", SmartPtr(new MyClass("2222"))); | ||
static_cast<dmlc::Stream *>(&fs)->Write<std::vector<Pair>>(myclasses); | ||
|
||
dmlc::MemoryStringStream ofs(&blob); | ||
std::vector<Pair> copy_myclasses; | ||
static_cast<dmlc::Stream *>(&ofs)->Read<std::vector<Pair>>(©_myclasses); | ||
|
||
EXPECT_TRUE(std::equal(myclasses.begin(), myclasses.end(), | ||
copy_myclasses.begin(), | ||
[](const Pair &left, const Pair &right) { | ||
return (left.second->data_ == right.second->data_) && | ||
(left.first == right.first); | ||
})); | ||
} | ||
|
||
TYPED_TEST(SmartPtrTest, Vector_Test2) { | ||
std::string blob; | ||
dmlc::MemoryStringStream fs(&blob); | ||
using SmartPtr = typename TestFixture::SmartPtr; | ||
auto my1 = SmartPtr(new MyClass("@A@")); | ||
auto my2 = SmartPtr(new MyClass("2222")); | ||
|
||
std::vector<SmartPtr> myclasses; | ||
myclasses.emplace_back(new MyClass("@A@")); | ||
myclasses.emplace_back(new MyClass("2222")); | ||
static_cast<dmlc::Stream *>(&fs)->Write<std::vector<SmartPtr>>(myclasses); | ||
|
||
dmlc::MemoryStringStream ofs(&blob); | ||
std::vector<SmartPtr> copy_myclasses; | ||
static_cast<dmlc::Stream *>(&ofs)->Read<std::vector<SmartPtr>>( | ||
©_myclasses); | ||
|
||
EXPECT_TRUE(std::equal(myclasses.begin(), myclasses.end(), | ||
copy_myclasses.begin(), | ||
[](const SmartPtr &left, const SmartPtr &right) { | ||
return left->data_ == right->data_; | ||
})); | ||
} |