Skip to content

Commit

Permalink
[Eager] Allow set dynamic attribute to eager tensor instance (PaddleP…
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Dec 19, 2024
1 parent ab1043a commit 9622b41
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 0 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/eager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1460,6 +1460,7 @@ static void TensorDealloc(TensorObject* self) {
if (self->weakrefs != nullptr)
PyObject_ClearWeakRefs(reinterpret_cast<PyObject*>(self));
self->tensor.~Tensor();
Py_XDECREF(self->dict);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}

Expand Down Expand Up @@ -1502,6 +1503,7 @@ void BindEager(pybind11::module* module) {
type->tp_base = reinterpret_cast<PyTypeObject*>(&PyBaseObject_Type);
type->tp_flags |=
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; // NOLINT
type->tp_dictoffset = offsetof(TensorObject, dict);
#if PY_VERSION_HEX >= 0x03050000
type->tp_as_async = &heap_type->as_async;
#endif
Expand Down Expand Up @@ -1550,6 +1552,7 @@ void BindEagerStringTensor(pybind11::module* module) {
type->tp_base = reinterpret_cast<PyTypeObject*>(&PyBaseObject_Type);
type->tp_flags |=
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; // NOLINT
type->tp_dictoffset = offsetof(TensorObject, dict);
#if PY_VERSION_HEX >= 0x03050000
type->tp_as_async = &heap_type->as_async;
#endif
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/pybind/eager_properties.cc
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,16 @@ PyObject* tensor_properties_get_grad_fn(TensorObject* self, void* closure) {
EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyObject* tensor_properties___dict__(TensorObject* self, void*) {
EAGER_TRY
if (self->dict == nullptr) {
self->dict = PyDict_New();
}
Py_INCREF(self->dict);
return self->dict;
EAGER_CATCH_AND_THROW_RETURN_NULL
}

struct PyGetSetDef variable_properties[] = { // NOLINT
{"data",
(getter)tensor_properties_get_data,
Expand Down Expand Up @@ -1036,6 +1046,7 @@ struct PyGetSetDef variable_properties[] = { // NOLINT
nullptr,
nullptr,
nullptr},
{"__dict__", (getter)tensor_properties___dict__, nullptr, nullptr, nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr}};

// variable_properties for core.eager.StringTensor
Expand All @@ -1053,6 +1064,7 @@ struct PyGetSetDef string_tensor_variable_properties[] = { // NOLINT
nullptr,
nullptr,
nullptr},
{"__dict__", (getter)tensor_properties___dict__, nullptr, nullptr, nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr}};

} // namespace pybind
Expand Down
2 changes: 2 additions & 0 deletions paddle/utils/pybind.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace pybind {

typedef struct {
PyObject_HEAD paddle::Tensor tensor;
// Dynamic attributes
PyObject* dict;
// Weak references
PyObject* weakrefs;
} TensorObject;
Expand Down
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_tensor_attr_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
'__len__',
'__long__',
'__nonzero__',
'__dict__',
'apply_',
'backward',
'clear_grad',
Expand Down
19 changes: 19 additions & 0 deletions test/legacy_test/test_eager_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,5 +1867,24 @@ def test_same_place_data_ptr_consistency(self):
self.assertEqual(x.data_ptr(), y.data_ptr())


class TestSetDynamicAttributeToEagerTensorInstance(unittest.TestCase):
def test_set_dynamic_attribute_to_eager_tensor_instance_create_via_constructor(
self,
):
tensor_instance = paddle.to_tensor(1.0)
tensor_instance._custom_id = 0
self.assertEqual(tensor_instance._custom_id, 0)
self.assertEqual(tensor_instance.__dict__["_custom_id"], 0)

def test_set_dynamic_attribute_to_eager_tensor_instance_create_via_to_pyobject(
self,
):
original_tensor = paddle.to_tensor(-1.0)
tensor_instance = paddle.abs(original_tensor)
tensor_instance._custom_flag = True
self.assertEqual(tensor_instance._custom_flag, True)
self.assertEqual(tensor_instance.__dict__["_custom_flag"], True)


if __name__ == "__main__":
unittest.main()

0 comments on commit 9622b41

Please sign in to comment.