Skip to content

Commit

Permalink
Fix segmentation fault when calling ray.put on a dictionary with obje…
Browse files Browse the repository at this point in the history
…ct keys (ray-project#548)

* fix segfault when serializing dict key

* fix style

* fix test

* Fix linting.
  • Loading branch information
ericl authored and pcmoritz committed May 15, 2017
1 parent 3c53753 commit e2e9e4c
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# The build output should clearly not be checked in
/python/ray/core
/src/common/thirdparty/redis
/numbuf/thirdparty/arrow
/src/numbuf/thirdparty/arrow

# Files generated by flatc should be ignored
/src/common/format/*.py
Expand Down
7 changes: 4 additions & 3 deletions src/numbuf/cpp/src/numbuf/dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ using namespace arrow;
namespace numbuf {

Status DictBuilder::Finish(std::shared_ptr<Array> key_tuple_data,
std::shared_ptr<Array> val_list_data, std::shared_ptr<Array> val_tuple_data,
std::shared_ptr<Array> val_dict_data, std::shared_ptr<arrow::Array>* out) {
std::shared_ptr<Array> key_dict_data, std::shared_ptr<Array> val_list_data,
std::shared_ptr<Array> val_tuple_data, std::shared_ptr<Array> val_dict_data,
std::shared_ptr<arrow::Array>* out) {
// lists and dicts can't be keys of dicts in Python, that is why for
// the keys we do not need to collect sublists
std::shared_ptr<Array> keys, vals;
RETURN_NOT_OK(keys_.Finish(nullptr, key_tuple_data, nullptr, &keys));
RETURN_NOT_OK(keys_.Finish(nullptr, key_tuple_data, key_dict_data, &keys));
RETURN_NOT_OK(vals_.Finish(val_list_data, val_tuple_data, val_dict_data, &vals));
auto keys_field = std::make_shared<Field>("keys", keys->type());
auto vals_field = std::make_shared<Field>("vals", vals->type());
Expand Down
1 change: 1 addition & 0 deletions src/numbuf/cpp/src/numbuf/dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class DictBuilder {
value list of the dictionary
*/
arrow::Status Finish(std::shared_ptr<arrow::Array> key_tuple_data,
std::shared_ptr<arrow::Array> key_dict_data,
std::shared_ptr<arrow::Array> val_list_data,
std::shared_ptr<arrow::Array> val_tuple_data,
std::shared_ptr<arrow::Array> val_dict_data, std::shared_ptr<arrow::Array>* out);
Expand Down
13 changes: 10 additions & 3 deletions src/numbuf/python/src/pynumbuf/adapters/python.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,13 @@ Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
"This object exceeds the maximum recursion depth. It may contain itself "
"recursively.");
}
std::vector<PyObject *> key_tuples, val_lists, val_tuples, val_dicts, dummy;
std::vector<PyObject *> key_tuples, key_dicts, val_lists, val_tuples, val_dicts, dummy;
for (const auto& dict : dicts) {
PyObject *key, *value;
Py_ssize_t pos = 0;
while (PyDict_Next(dict, &pos, &key, &value)) {
RETURN_NOT_OK(append(key, result.keys(), dummy, key_tuples, dummy, tensors_out));
RETURN_NOT_OK(
append(key, result.keys(), dummy, key_tuples, key_dicts, tensors_out));
DCHECK(dummy.size() == 0);
RETURN_NOT_OK(
append(value, result.vals(), val_lists, val_tuples, val_dicts, tensors_out));
Expand All @@ -245,6 +246,11 @@ Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
RETURN_NOT_OK(SerializeSequences(
key_tuples, recursion_depth + 1, &key_tuples_arr, tensors_out));
}
std::shared_ptr<Array> key_dicts_arr;
if (key_dicts.size() > 0) {
RETURN_NOT_OK(
SerializeDict(key_dicts, recursion_depth + 1, &key_dicts_arr, tensors_out));
}
std::shared_ptr<Array> val_list_arr;
if (val_lists.size() > 0) {
RETURN_NOT_OK(
Expand All @@ -260,7 +266,8 @@ Status SerializeDict(std::vector<PyObject*> dicts, int32_t recursion_depth,
RETURN_NOT_OK(
SerializeDict(val_dicts, recursion_depth + 1, &val_dict_arr, tensors_out));
}
result.Finish(key_tuples_arr, val_list_arr, val_tuples_arr, val_dict_arr, out);
result.Finish(
key_tuples_arr, key_dicts_arr, val_list_arr, val_tuples_arr, val_dict_arr, out);

// This block is used to decrement the reference counts of the results
// returned by the serialization callback, which is called in SerializeArray
Expand Down
13 changes: 10 additions & 3 deletions test/runtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,14 @@ def assert_equal(obj1, obj2):


class Foo(object):
def __init__(self):
pass
def __init__(self, value=0):
self.value = value

def __hash__(self):
return hash(self.value)

def __eq__(self, other):
return other.value == self.value


class Bar(object):
Expand Down Expand Up @@ -139,7 +145,8 @@ class CustomError(Exception):
DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS
if (obj.__hash__ is not None and
type(obj).__module__ != "numpy")] +
[{0: obj} for obj in BASE_OBJECTS])
[{0: obj} for obj in BASE_OBJECTS] +
[{Foo(123): Foo(456)}])

RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS

Expand Down

0 comments on commit e2e9e4c

Please sign in to comment.