Simplify ScriptModule bindings. (#29432)
Summary:
Pull Request resolved: pytorch/pytorch#29432

This removes a lot of the private methods on torch._C.ScriptModule
and instead implements the functionality in terms of slot_dict_impl views,
which back _parameters, _buffers, and _modules in nn.Module.

A followup PR should also remove the _register_attribute,
_register_module, and _register_parameter methods, but this requires
more refactoring of the way tracing creates modules and of how
replication for data parallel works.
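
For concreteness, a minimal sketch of the new surface, mirroring the updated tests below (the attribute and dict names are the ones the tests exercise):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, x):
        return self.conv(x)

m = torch.jit.script(M())

# getattr/hasattr on the C++ module object replace the old
# _get_attribute/_has_attribute private methods.
assert m._c.hasattr('training')
assert m._c.getattr('training') == m.training

# Submodules are enumerated through the slot_dict view that backs
# nn.Module._modules, replacing the old _get_modules().
assert [name for name, _ in m._modules._c.items()] == ['conv']
```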

Test Plan: Imported from OSS

Differential Revision: D18387963

Pulled By: zdevito

fbshipit-source-id: f10d47afeb30c1e05d704ae5ac4166830933125c
zdevito authored and facebook-github-bot committed Nov 11, 2019
1 parent 5b702ab commit 4e4e29a
Showing 7 changed files with 83 additions and 159 deletions.
test/jit/test_recursive_script.py (4 changes: 2 additions & 2 deletions)
@@ -113,12 +113,12 @@ def forward(self):

# sm1 was created while m had training = True
self.assertTrue(sm1.training)
- self.assertEqual(sm1.training, sm1._c._get_attribute('training'))
+ self.assertEqual(sm1.training, sm1._c.getattr('training'))
self.assertEqual(sm1(), 2)

# sm2 was created after m was eval'ed
self.assertFalse(sm2.training)
- self.assertEqual(sm2.training, sm2._c._get_attribute('training'))
+ self.assertEqual(sm2.training, sm2._c.getattr('training'))
self.assertEqual(sm2(), 0)

def test_module_name(self):
test/jit_utils.py (2 changes: 1 addition & 1 deletion)
@@ -612,4 +612,4 @@ def get_forward_graph(c):
return c._get_method('forward').graph

def get_module_method(m, module, method):
- return m._c._get_module(module)._get_method(method)
+ return m._c.getattr(module)._get_method(method)
test/test_jit.py (34 changes: 17 additions & 17 deletions)
@@ -940,15 +940,15 @@ def forward(self, x):
weight=observer._c)
}
torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, True)
- assert len([x for x, _ in m._c._get_modules()
+ assert len([x for x, _ in m._modules._c.items()
if x.startswith('_observer_')]) == 0, \
'Expected to have 0 observer submodules'
FileCheck().check_not('ClassType<Observer> = prim::GetAttr[name="_observer_') \
.check('ClassType<Conv2d> = prim::GetAttr[name="conv"](%self)') \
.check_next('Tensor = prim::CallMethod[name="forward"]') \
.check_not('ClassType<Observer> = prim::GetAttr[name="_observer_') \
.run(str(get_forward_graph(m._c)))
- assert len([x for x, _ in m._c._get_module('conv')._get_modules()
+ assert len([x for x, _ in m.conv._modules._c.items()
if x.startswith('_observer_')]) == 3, \
'Expected to have 3 observer submodules'
FileCheck().check('ClassType<Observer> = prim::GetAttr[name="_observer_') \
@@ -958,7 +958,7 @@ def forward(self, x):
.check('Tensor = aten::conv2d') \
.check('ClassType<Observer> = prim::GetAttr[name="_observer_') \
.check_next('prim::CallMethod[name="forward"](%_observer_') \
- .run(str(m._c._get_module("conv")._get_method('conv2d_forward').graph))
+ .run(str(m._c.getattr("conv")._get_method('conv2d_forward').graph))

@_tmp_donotuse_dont_inline_everything
def test_insert_observers_child_qconfig(self):
@@ -1010,13 +1010,13 @@ def check_not_observed(s):
# check m is not observed
check_not_observed(get_forward_graph(m._c))
# check conv.forward is observed
- check_not_observed(get_forward_graph(m._c._get_module('conv')))
+ check_not_observed(get_forward_graph(m._c.getattr('conv')))
# check conv.conv2d_forward is observed
check_observed(get_module_method(m, 'conv', 'conv2d_forward').graph)
# check sub is not observed
check_not_observed(get_module_method(m, 'sub', 'forward'))
# check forward of sub.linear is observed
- check_observed(get_forward(m._c._get_module('sub')._get_module('linear')).graph)
+ check_observed(get_forward(m._c.getattr('sub').getattr('linear')).graph)

@_tmp_donotuse_dont_inline_everything
def test_insert_observers_skip_values(self):
@@ -1053,7 +1053,7 @@ def test_module(module, relu_call, num_observers):
weight=observer._c)
}
torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, True)
- assert len([x for x, _ in m._c._get_modules()
+ assert len([x for x, _ in m._modules._c.items()
if x.startswith('_observer_')]) == num_observers, \
'Expected to have ' + str(num_observers) + ' observer submodules'
c = FileCheck().check('ClassType<Conv2d> = prim::GetAttr[name="conv"]') \
@@ -1089,7 +1089,8 @@ def forward(self, x):
weight=weight_observer._c)
}
torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, True)
- dtypes = set([obs._get_attribute('dtype') for x, obs in m._c._get_module('conv')._get_modules()
+ print()
+ dtypes = set([obs.getattr('dtype') for x, obs in m.conv._modules._c.items()
if x.startswith('_observer_')])
assert len(dtypes) == 2, 'Expected to have 2 different types of dtype'

@@ -1393,7 +1394,7 @@ def forward(self, x):

m = torch.jit.script(M())
torch._C._jit_pass_fold_quantize(m._c, 'forward')
- self.assertTrue(m._c._has_attribute('_quantized_weight'))
+ self.assertTrue(m._c.hasattr('_quantized_weight'))
FileCheck().check_not('GetAttr[name="weight"]') \
.check('GetAttr[name="_quantized_weight"]') \
.run(m._c._get_method('forward').graph)
@@ -1453,16 +1454,16 @@ def forward(self, x):
conv_packed_params)
res = get_forward(m._c)(data)
# check attribute and graph
- packed_module_list = [x for x, _ in m._c._get_modules()
+ packed_module_list = [x for x, _ in m._modules._c.items()
if x.startswith('_' + name + '_packed_params_module')]
assert len(packed_module_list) == 1, \
'Expected to have one packed_params_module'
packed_module_name = packed_module_list[0]
# check values
- original_w = m._c._get_parameter('weight')
+ original_w = m.weight
ref_w = torch.quantize_per_tensor(original_w, 0.2, 1, torch.qint8).dequantize()
- ref_b = m._c._get_parameter('bias')
- w, b = m._c._get_module(packed_module_name)._get_method('_weight_bias')()
+ ref_b = m.bias
+ w, b = m._c.getattr(packed_module_name)._get_method('_weight_bias')()
self.assertEqual(ref_w, w.dequantize())
self.assertEqual(ref_b, b)
self.assertEqual(ref_res, res)
@@ -3618,12 +3619,12 @@ def test_eval_python(self):
def _test(m):
self.assertTrue(m(torch.ones(2, 2)))
self.assertTrue(m.training)
- self.assertTrue(m._c._get_attribute('training'))
+ self.assertTrue(m._c.getattr('training'))

m.eval()

self.assertFalse(m.training)
- self.assertFalse(m._c._get_attribute('training'))
+ self.assertFalse(m._c.getattr('training'))
self.assertFalse(m(torch.ones(2, 2)))

if not PY2:
@@ -3634,7 +3635,7 @@ def _test(m):
loaded = torch.jit.load(buffer)

self.assertFalse(loaded.training)
- self.assertFalse(loaded._c._get_attribute('training'))
+ self.assertFalse(loaded._c.getattr('training'))

class M(nn.Module):
def __init__(self):
@@ -3898,7 +3899,7 @@ def replace(e):
elif e is a.foo.b:
return 'B'
elif isinstance(e, torch._C.ScriptModule):
- return e._get_attribute('name')
+ return e.getattr('name')

return e
for k, v in result.items():
@@ -12124,7 +12125,6 @@ def forward(self, x):
self.assertTrue(imported.ssm.asm._c._has_method('bar'))
self.assertTrue(hasattr(imported.ssm.asm, 'bar'))

- self.assertTrue(imported.ssm.asm._c._has_parameter('param'))
self.assertTrue(hasattr(imported.ssm.asm, 'param'))

def test_trace_parameter(self):
torch/csrc/jit/script/init.cpp (103 changes: 34 additions & 69 deletions)
@@ -433,16 +433,6 @@ bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
return true;
}

- // used to temporarily implement _has_attribute in the python wrapper
- // we should replace these individual functions with direct bindings to the
- // _parameters, _modules, and _buffers dictionaries
- struct LegacyAttributePolicy {
-   static bool valid(const ClassTypePtr& typ, size_t i) {
-     return !detail::ModulePolicy::valid(typ, i) &&
-         !detail::ParameterPolicy::valid(typ, i);
-   }
- };

// helper used to implement ._parameters, ._buffers, ._modules dicts
// inside of script nn.Module
template <typename Policy>
@@ -470,48 +460,27 @@ struct slot_dict_impl {
}

void setattr(const std::string& name, py::object value) {
- size_t N = module_->type()->getAttributeSlot(name);
- const TypePtr& type = module_->type()->getAttribute(N);
- module_->setSlot(N, toIValue(std::move(value), type));
+ const TypePtr& type = module_->type()->getAttribute(name);
+ script::Module(module_).setattr(name, toIValue(std::move(value), type));
}

+ py::object getattr(const std::string& name) {
+   return toPyObject(script::Module(module_).attr(name));
+ }

+ static void bind(const py::module& m, const char* name) {
+   py::class_<slot_dict_impl<Policy>>(m, name)
+       .def(py::init(
+           [](Module& m) { return slot_dict_impl<Policy>(m.module_object()); }))
+       .def("contains", &slot_dict_impl<Policy>::contains)
+       .def("items", &slot_dict_impl<Policy>::items)
+       .def("setattr", &slot_dict_impl<Policy>::setattr)
+       .def("getattr", &slot_dict_impl<Policy>::getattr);
+ }
private:
script::ModulePtr module_;
};

- // helpers to implement _set_parameter, _get_parameter, _has_parameter, etc.
- // these can be removed once everything works directly from bound slot_dict
- // objects
- template <typename Policy>
- static void set_generic(
-     Module& self,
-     const std::string& name,
-     py::object value) {
-   slot_dict_impl<Policy>(self.module_object()).setattr(name, std::move(value));
- }
-
- static py::object get_generic(Module& self, const std::string& name) {
-   return toPyObject(self.attr(name));
- }
-
- template <typename Policy>
- static py::tuple get_generic_list(Module& self) {
-   auto the_list = script::slot_list_impl<script::detail::NamedPolicy<Policy>>(
-       self.module_object(), false, false);
-   py::tuple result(the_list.size());
-   auto i = 0;
-   for (const auto& e : the_list) {
-     py::tuple r(2);
-     result[i++] = std::make_tuple(e.name, toPyObject(e.value));
-   }
-   return result;
- }
-
- template <typename Policy>
- static bool has_generic(Module& self, const std::string& name) {
-   return slot_dict_impl<Policy>(self.module_object()).contains(name);
- }

template <typename T>
py::list debugMakeList(const T& list) {
py::list result;
@@ -555,6 +524,7 @@ static py::dict _jit_debug_module_iterators(Module& module) {
return result;
}

+
void initJitScriptBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();

@@ -625,30 +595,24 @@ void initJitScriptBindings(PyObject* module) {
name, unshaped, toIValue(std::move(value), type));
})
.def("_register_module", &Module::register_module)
- .def("_register_buffer", &Module::register_buffer)
.def(
-     "_set_attribute",
+     "setattr",
[](Module& self, const std::string& name, py::object value) {
-       auto ivalue =
-           toIValue(std::move(value), self.type()->getAttribute(name));
+       TypePtr type = self.type()->getAttribute(name);
+       TORCH_CHECK(type, "Module has no attribute '", name, "'");
+       auto ivalue = toIValue(std::move(value), type);
self.setattr(name, ivalue);
})
- .def("_set_parameter", &set_generic<detail::ParameterPolicy>)
- .def("_get_parameter", &get_generic)
- .def("_get_buffer", &get_generic)
- .def("_get_attribute", &get_generic)
- .def("_get_module", &get_generic)
.def(
-     "_get_modules",
-     [](Module& self) {
-       std::vector<std::pair<std::string, Module>> modules;
-       for (const NameModule& s : self.named_children()) {
-         modules.emplace_back(std::make_pair(s.name, s.value));
-       }
-       return modules;
+     "getattr",
+     [](Module& self, const std::string& name) {
+       return toPyObject(self.attr(name));
})
+ .def(
+     "hasattr",
+     [](Module& self, const std::string& name) {
+       return self.hasattr(name);
+     })
- .def("_get_parameters", get_generic_list<script::detail::ParameterPolicy>)
- .def("_get_buffers", get_generic_list<script::detail::BufferPolicy>)
.def(
"_replicate_for_data_parallel",
[](Module& module) {
@@ -658,7 +622,8 @@ void initJitScriptBindings(PyObject* module) {
/*should_mangle*/ true);
ClassTypePtr module_cls = module.module_object()->type();
for (size_t i = 0, N = module_cls->numAttributes(); i < N; ++i) {
- if (LegacyAttributePolicy::valid(module_cls, i) &&
+ if (!detail::ModulePolicy::valid(module_cls, i) &&
+     !detail::ParameterPolicy::valid(module_cls, i) &&
!detail::BufferPolicy::valid(module_cls, i)) {
replica.register_attribute(
module_cls->getAttributeName(i),
@@ ... @@
}
return replica;
})
- .def("_has_attribute", has_generic<LegacyAttributePolicy>)
- .def("_has_parameter", has_generic<script::detail::ParameterPolicy>)
- .def("_has_buffer", has_generic<script::detail::BufferPolicy>)
- .def("_has_module", has_generic<script::detail::ModulePolicy>)
.def(
"_has_method",
[](Module& self, const std::string& name) {
@@ -731,6 +692,10 @@ void initJitScriptBindings(PyObject* module) {
m.clone_method(orig, name);
});

+ slot_dict_impl<script::detail::ParameterPolicy>::bind(m, "ParameterDict");
+ slot_dict_impl<script::detail::BufferPolicy>::bind(m, "BufferDict");
+ slot_dict_impl<script::detail::ModulePolicy>::bind(m, "ModuleDict");

py::class_<ErrorReport, std::shared_ptr<ErrorReport>>(m, "ErrorReport")
.def(py::init<SourceRange>())
.def("what", &ErrorReport::what);
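The three bind calls above expose the views to Python; a minimal usage sketch, assuming the classes land on torch._C under the names registered above:

```python
import torch
import torch.nn as nn

m = torch.jit.script(nn.Linear(2, 2))

# A parameter view over the C++ module object, constructed the way the
# py::init lambda above expects (from the script module's _c object).
params = torch._C.ParameterDict(m._c)

assert params.contains('weight')
assert [name for name, _ in params.items()] == ['weight', 'bias']

# setattr converts the value to an IValue of the slot's type and writes
# through to the module; getattr reads the slot back as a Python object.
params.setattr('bias', torch.zeros(2))
assert bool(params.getattr('bias').eq(0).all())
```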
torch/csrc/jit/script/module.h (1 change: 1 addition & 0 deletions)
@@ -199,6 +199,7 @@ struct TORCH_API Module {
void setattr(const std::string& name, IValue v) {
size_t slot = module_object()->type()->getAttributeSlot(name);
const TypePtr& expected = module_object()->type()->getAttribute(slot);
+ TORCH_CHECK(expected, "Module has no attribute '", name, "'");
TORCH_CHECK(
v.type()->isSubtypeOf(expected),
"Expected a value of type '",
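With this check, a bad attribute name fails with a readable error instead of an unchecked slot lookup; a sketch of the Python-visible behavior (the message text matches the TORCH_CHECK above):

```python
import torch
import torch.nn as nn

m = torch.jit.script(nn.Linear(2, 2))

try:
    # The setattr binding validates the attribute name before converting
    # the value, so this raises rather than crashing.
    m._c.setattr('not_an_attribute', 1)
except RuntimeError as e:
    assert "Module has no attribute 'not_an_attribute'" in str(e)
```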