Skip to content

Commit

Permalink
[misc] Improve ::liong::json::deserialize()
Browse files Browse the repository at this point in the history
ghstack-source-id: c8e3a80a232a5731038d0309e7a05784b198b645
Pull Request resolved: taichi-dev#7789
  • Loading branch information
PGZXB authored and Taichi Gardener committed Apr 14, 2023
1 parent 43f38b1 commit b78e28d
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 47 deletions.
2 changes: 1 addition & 1 deletion taichi/common/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class JsonException : public std::exception {
std::string msg_;

public:
explicit JsonException(const char *msg) : msg_(msg) {
explicit JsonException(std::string_view msg) : msg_(msg) {
}
const char *what() const noexcept override {
return msg_.c_str();
Expand Down
105 changes: 69 additions & 36 deletions taichi/common/json_serde.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ struct has_ptr_serde {
static constexpr auto helper(T_ *) -> std::is_same<
decltype((T_::jsonserde_ptr_io(std::declval<const T_ *&>(),
std::declval<JsonValue &>(),
std::declval<bool>(),
std::declval<bool>()))),
void>;

Expand Down Expand Up @@ -118,7 +119,9 @@ struct JsonSerde {
T> &x) {
JsonValue val;
using T_ = std::remove_pointer_t<T>;
T_::jsonserde_ptr_io((const T_ *&)x, val, /*writing=*/true);
// NOTE: strict is not used if writing is true
T_::jsonserde_ptr_io((const T_ *&)x, val, /*writing=*/true,
/*strict=*/false);
return val;
}

Expand All @@ -128,9 +131,11 @@ struct JsonSerde {
typename std::enable_if_t<
std::is_pointer_v<U> &&
has_ptr_serde<std::remove_pointer_t<U>>::value,
T> &x) {
T> &x,
bool strict) {
using T_ = std::remove_pointer_t<T>;
T_::jsonserde_ptr_io((const T_ *&)x, get_writable(j), /*writing=*/false);
T_::jsonserde_ptr_io((const T_ *&)x, get_writable(j), /*writing=*/false,
/*strict=*/strict);
}

// Numeric and boolean types (integers and floating-point numbers).
Expand All @@ -142,7 +147,8 @@ struct JsonSerde {
template <typename U = typename std::remove_cv<T>::type>
static void deserialize(
const JsonValue &j,
typename std::enable_if_t<std::is_arithmetic<U>::value, T> &x) {
typename std::enable_if_t<std::is_arithmetic<U>::value, T> &x,
bool strict) {
x = (T)j;
}
template <typename U = typename std::remove_cv<T>::type>
Expand All @@ -153,7 +159,8 @@ struct JsonSerde {
template <typename U = typename std::remove_cv<T>::type>
static void deserialize(
const JsonValue &j,
typename std::enable_if_t<std::is_enum<U>::value, T> &x) {
typename std::enable_if_t<std::is_enum<U>::value, T> &x,
bool strict) {
x = (T)(typename std::underlying_type<T>::type)j;
}

Expand All @@ -166,7 +173,8 @@ struct JsonSerde {
template <typename U = typename std::remove_cv<T>::type>
static void deserialize(
const JsonValue &j,
typename std::enable_if_t<std::is_same<U, std::string>::value, T> &x) {
typename std::enable_if_t<std::is_same<U, std::string>::value, T> &x,
bool strict) {
x = (T)j;
}

Expand All @@ -184,10 +192,12 @@ struct JsonSerde {
const JsonValue &j,
typename std::enable_if_t<
std::is_same<decltype(std::declval<U>().json_deserialize_fields(
std::declval<const JsonObject &>())),
std::declval<const JsonObject &>(),
std::declval<bool>())),
void>::value,
T> &x) {
x.json_deserialize_fields((const JsonObject &)j);
T> &x,
bool strict) {
x.json_deserialize_fields((const JsonObject &)j, strict);
}

// Key-value pairs.
Expand All @@ -210,9 +220,11 @@ struct JsonSerde {
typename std::enable_if_t<std::is_same<std::pair<typename U::first_type,
typename U::second_type>,
T>::value,
T> &x) {
JsonSerde<typename T::first_type>::deserialize(j["key"], x.first);
JsonSerde<typename T::second_type>::deserialize(j["value"], x.second);
T> &x,
bool strict) {
JsonSerde<typename T::first_type>::deserialize(j["key"], x.first, strict);
JsonSerde<typename T::second_type>::deserialize(j["value"], x.second,
strict);
}

// Owned pointer (requires default constructable).
Expand All @@ -232,12 +244,13 @@ struct JsonSerde {
const JsonValue &j,
typename std::enable_if_t<
std::is_same<std::unique_ptr<typename U::element_type>, T>::value,
T> &x) {
T> &x,
bool strict) {
if (j.is_null()) {
x = nullptr;
} else {
x = std::make_unique<typename T::element_type>();
JsonSerde<typename T::element_type>::deserialize(j, *x);
JsonSerde<typename T::element_type>::deserialize(j, *x, strict);
}
}

Expand Down Expand Up @@ -278,9 +291,11 @@ struct JsonSerde {
template <typename U = typename std::remove_cv<T>::type>
static void deserialize(
const JsonValue &j,
typename std::enable_if_t<std::is_array<U>::value, T> &x) {
typename std::enable_if_t<std::is_array<U>::value, T> &x,
bool strict) {
for (size_t i = 0; i < std::extent<T>::value; ++i) {
JsonSerde<typename std::remove_extent_t<T>>::deserialize(j[i], x[i]);
JsonSerde<typename std::remove_extent_t<T>>::deserialize(j[i], x[i],
strict);
}
}
template <typename U = typename std::remove_cv<T>::type>
Expand All @@ -290,21 +305,23 @@ struct JsonSerde {
std::is_same<
std::array<typename U::value_type, std::tuple_size<U>::value>,
T>::value,
T> &x) {
T> &x,
bool strict) {
for (size_t i = 0; i < x.size(); ++i) {
JsonSerde<typename T::value_type>::deserialize(j[i], x.at(i));
JsonSerde<typename T::value_type>::deserialize(j[i], x.at(i), strict);
}
}
template <typename U = typename std::remove_cv<T>::type>
static void deserialize(
const JsonValue &j,
typename std::enable_if_t<
std::is_same<std::vector<typename U::value_type>, T>::value,
T> &x) {
T> &x,
bool strict) {
x.clear();
for (const auto &elem : j.elems()) {
typename T::value_type xx{};
JsonSerde<decltype(xx)>::deserialize(elem, xx);
JsonSerde<decltype(xx)>::deserialize(elem, xx, strict);
x.emplace_back(std::move(xx));
}
}
Expand Down Expand Up @@ -341,11 +358,12 @@ struct JsonSerde {
typename std::enable_if_t<
std::is_same<std::map<typename U::key_type, typename U::mapped_type>,
T>::value,
T> &x) {
T> &x,
bool strict) {
x.clear();
for (const auto &elem : j.elems()) {
std::pair<typename T::key_type, typename T::mapped_type> xx{};
JsonSerde<decltype(xx)>::deserialize(elem, xx);
JsonSerde<decltype(xx)>::deserialize(elem, xx, strict);
x.emplace(std::move(*(std::pair<const typename T::key_type,
typename T::mapped_type> *)&xx));
}
Expand All @@ -357,11 +375,12 @@ struct JsonSerde {
std::is_same<
std::unordered_map<typename U::key_type, typename U::mapped_type>,
T>::value,
T> &x) {
T> &x,
bool strict) {
x.clear();
for (const auto &elem : j.elems()) {
std::pair<typename T::key_type, typename T::mapped_type> xx{};
JsonSerde<decltype(xx)>::deserialize(elem, xx);
JsonSerde<decltype(xx)>::deserialize(elem, xx, strict);
x.emplace(std::move(*(std::pair<const typename T::key_type,
typename T::mapped_type> *)&xx));
}
Expand All @@ -384,12 +403,13 @@ struct JsonSerde {
const JsonValue &j,
typename std::enable_if_t<
std::is_same<std::optional<typename U::value_type>, T>::value,
T> &x) {
T> &x,
bool strict) {
if (j.is_null()) {
x = std::nullopt;
} else {
typename T::value_type xx;
JsonSerde<typename T::value_type>::deserialize(j, xx);
JsonSerde<typename T::value_type>::deserialize(j, xx, strict);
x = std::move(xx);
}
}
Expand All @@ -408,24 +428,29 @@ struct JsonSerdeFieldImpl<TFirst, TOthers...> {
JsonSerdeFieldImpl<TOthers...>::serialize(obj, ++name, others...);
}
inline static void deserialize(const JsonObject &obj,
bool strict,
std::vector<std::string>::const_iterator name,
TFirst &first,
TOthers &...others) {
auto it = obj.inner.find(*name);
if (it != obj.inner.end()) {
JsonSerde<TFirst>::deserialize(it->second, first);
JsonSerde<TFirst>::deserialize(it->second, first, strict);
} else if (strict) {
throw ::liong::json::JsonException("Missing field: " + *name);
}
JsonSerdeFieldImpl<TOthers...>::deserialize(obj, ++name, others...);
JsonSerdeFieldImpl<TOthers...>::deserialize(obj, strict, ++name, others...);
}
};
template <>
struct JsonSerdeFieldImpl<> {
inline static void serialize(JsonObject &obj,
std::vector<std::string>::const_iterator name) {
}
inline static void deserialize(
inline static bool deserialize(
const JsonObject &obj,
bool strict,
std::vector<std::string>::const_iterator name) {
return true;
}
};
template <typename... TArgs>
Expand All @@ -438,9 +463,13 @@ inline void json_serialize_field_impl(
template <typename... TArgs>
inline void json_deserialize_field_impl(
const JsonObject &obj,
bool strict,
std::vector<std::string>::const_iterator name,
TArgs &...args) {
JsonSerdeFieldImpl<TArgs...>::deserialize(obj, name, args...);
if (strict && obj.inner.size() != sizeof...(TArgs)) {
throw ::liong::json::JsonException("unexpected number of fields");
}
return JsonSerdeFieldImpl<TArgs...>::deserialize(obj, strict, name, args...);
}

} // namespace detail
Expand All @@ -453,10 +482,13 @@ JsonValue serialize(const T &x) {
}

// Deserialize a JSON serde object, turning JSON text into in-memory
// representations.
// representations. If `strict` is true, the function will throw JsonException
// if a field is missing or an extra field is present. Otherwise, the missing
// fields will be filled with default values and the extra fields will be
// ignored. See serialize_test.cpp for examples.
template <typename T>
void deserialize(const JsonValue &j, T &out) {
detail::JsonSerde<T>::deserialize(j, out);
void deserialize(const JsonValue &j, T &out, bool strict = false) {
detail::JsonSerde<T>::deserialize(j, out, strict);
}

// If you need to control the serialization process on your own, you might want
Expand All @@ -466,7 +498,7 @@ struct CustomJsonSerdeBase {
// Serialize the field values into a JSON object.
virtual JsonObject json_serialize_fields() const = 0;
// Deserialize the current object with JSON fields.
virtual void json_deserialize_fields(const JsonObject &j) = 0;
virtual void json_deserialize_fields(const JsonObject &j, bool strict) = 0;
};

} // namespace json
Expand All @@ -484,7 +516,8 @@ struct CustomJsonSerdeBase {
out, json_serde_field_names().begin(), __VA_ARGS__); \
return ::liong::json::JsonValue(std::move(out)); \
} \
void json_deserialize_fields(const ::liong::json::JsonObject &j) { \
void json_deserialize_fields(const ::liong::json::JsonObject &j, \
bool strict) { \
::liong::json::detail::json_deserialize_field_impl( \
j, json_serde_field_names().begin(), __VA_ARGS__); \
j, strict, json_serde_field_names().begin(), __VA_ARGS__); \
}
23 changes: 13 additions & 10 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class TI_DLL_EXPORT Type {

template <typename T>
static typename std::enable_if<std::is_base_of_v<Type, T>, void>::type
jsonserde_ptr_io(const T *&ptr, JsonValue &value, bool writing);
jsonserde_ptr_io(const T *&ptr, JsonValue &value, bool writing, bool strict);

// For serialization
virtual const Type *get_type() const = 0;
Expand Down Expand Up @@ -682,7 +682,10 @@ Type::ptr_io(const T *&ptr, S &serializer, bool writing) {

template <typename T>
typename std::enable_if<std::is_base_of_v<Type, T>, void>::type
Type::jsonserde_ptr_io(const T *&ptr, JsonValue &value, bool writing) {
Type::jsonserde_ptr_io(const T *&ptr,
JsonValue &value,
bool writing,
bool strict) {
if (writing) {
if (ptr == nullptr) {
value = JsonValue(nullptr);
Expand Down Expand Up @@ -713,14 +716,14 @@ Type::jsonserde_ptr_io(const T *&ptr, JsonValue &value, bool writing) {
}
TypeKind type_kind = (TypeKind)(int)value["type_kind"];
switch (type_kind) {
#define PER_TYPE_KIND(x) \
case TypeKind::x: { \
x##Type content; \
auto &content_val = value["content"]; \
TI_ASSERT(content_val.is_obj()); \
content.json_deserialize_fields(content_val.obj); \
ptr = content.get_type()->as<T>(); \
break; \
#define PER_TYPE_KIND(x) \
case TypeKind::x: { \
x##Type content; \
auto &content_val = value["content"]; \
TI_ASSERT(content_val.is_obj()); \
content.json_deserialize_fields(content_val.obj, strict); \
ptr = content.get_type()->as<T>(); \
break; \
}
#include "taichi/inc/type_kind.inc.h"
#undef PER_TYPE_KIND
Expand Down
53 changes: 53 additions & 0 deletions tests/cpp/common/serialization_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,58 @@ TEST(SERIALIZATION, Type) {
EXPECT_EQ(deserialized3->to_string(), quant_array_type->to_string());
}

struct Foo {
std::string k;
int v{-1};

bool operator==(const Foo &other) const {
return k == other.k && v == other.v;
}

TI_IO_DEF(k, v);
};

TEST(Serialization, JsonSerde) {
using namespace ::liong::json;

const auto kCorrectJson = R"({"k":"hello","v":42})";
const auto kWrongFieldNameJson = R"({"k":"hello","value":42})";
const auto kWrongFieldTypeJson = R"({"k":"hello","v":"42"})";
const auto kMissingFieldJson = R"({"k":"hello"})";
const auto kExtraFieldJson = R"({"k":"hello","v":42,"extra":1})";

Foo foo, t;
foo.k = "hello";
foo.v = 42;

// Serialize
EXPECT_EQ(kCorrectJson, print(serialize(foo)));

// Deserialize (correct)
deserialize(parse(kCorrectJson), t, true);
EXPECT_EQ(foo, t);

// Deserialize (wrong, on strict mode)
EXPECT_THROW(deserialize(parse(kWrongFieldNameJson), t, true), JsonException);
EXPECT_THROW(deserialize(parse(kWrongFieldTypeJson), t, true), JsonException);
EXPECT_THROW(deserialize(parse(kMissingFieldJson), t, true), JsonException);
EXPECT_THROW(deserialize(parse(kExtraFieldJson), t, true), JsonException);

// Deserialize (wrong, but on non-strict mode)
t = Foo{};
deserialize(parse(kWrongFieldNameJson), t, false); // no exception
EXPECT_EQ(foo.k, t.k);
EXPECT_EQ(-1, t.v); // default value

t = Foo{};
deserialize(parse(kMissingFieldJson), t, false); // no exception
EXPECT_EQ(foo.k, t.k);
EXPECT_EQ(-1, t.v); // default value

t = Foo{};
deserialize(parse(kExtraFieldJson), t, false); // no exception
EXPECT_EQ(foo, t);
}

} // namespace
} // namespace taichi::lang

0 comments on commit b78e28d

Please sign in to comment.