Skip to content

Commit

Permalink
Add docstring support to cwrap (pytorch#295)
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke authored Dec 11, 2016
1 parent 1af9a96 commit 28f0cf6
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ dist/
torch.egg-info/
*/**/__pycache__
torch/csrc/generic/TensorMethods.cpp
torch/csrc/TensorDocstrings.cpp
torch/csrc/TensorDocstrings.h
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
Expand Down
10 changes: 8 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,15 @@ def run(self):
from tools.cwrap.plugins.KwargsPlugin import KwargsPlugin
from tools.cwrap.plugins.NullableArguments import NullableArguments
from tools.cwrap.plugins.CuDNNPlugin import CuDNNPlugin
thp_plugin = THPPlugin()
cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[
BoolOption(), THPPlugin(), AutoGPU(condition='IS_CUDA'),
ArgcountSortPlugin(), KwargsPlugin(),
BoolOption(), thp_plugin, AutoGPU(condition='IS_CUDA'),
ArgcountSortPlugin(), KwargsPlugin()
])
with open('torch/csrc/TensorDocstrings.cpp', 'w') as f:
f.write(thp_plugin.generate_docstrings_cpp())
with open('torch/csrc/TensorDocstrings.h', 'w') as f:
f.write(thp_plugin.generate_docstrings_h())
cwrap('torch/csrc/cudnn/cuDNN.cwrap', plugins=[
CuDNNPlugin(), NullableArguments()
])
Expand Down Expand Up @@ -157,6 +162,7 @@ def run(self):
"torch/csrc/Size.cpp",
"torch/csrc/Exceptions.cpp",
"torch/csrc/Tensor.cpp",
"torch/csrc/TensorDocstrings.cpp",
"torch/csrc/Storage.cpp",
"torch/csrc/byte_order.cpp",
"torch/csrc/utils.cpp",
Expand Down
38 changes: 35 additions & 3 deletions tools/cwrap/plugins/THPPlugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from string import Template
from copy import deepcopy
from . import CWrapPlugin
from itertools import product
from itertools import product, chain
from collections import OrderedDict

class THPPlugin(CWrapPlugin):
Expand Down Expand Up @@ -144,6 +144,7 @@ def _allocate(typename, tmpl, cuda_tmpl=None):
def __init__(self):
self.declarations = []
self.stateless_declarations = []
self.docstrings = []

def get_type_unpack(self, arg, option):
return self.TYPE_UNPACK.get(arg['type'], None)
Expand Down Expand Up @@ -197,6 +198,20 @@ def get_arg_accessor(self, arg, option):
if 'allocate' in arg and arg['allocate']:
return arg['name']

def process_docstrings(self):
for declaration in self.declarations:
docstr = declaration.get('docstring_method')
if docstr is None:
continue
declaration['docstring_content'] = docstr.replace('\n', '\\n')
declaration['docstring_var'] = 'docstr_' + declaration['python_name']
for declaration in self.stateless_declarations:
docstr = declaration.get('docstring_stateless')
if docstr is None:
continue
declaration['docstring_content'] = docstr.replace('\n', '\\n')
declaration['docstring_var'] = 'stateless_docstr_' + declaration['python_name']

def process_declarations(self, declarations):
new_declarations = []
register_only = [d for d in declarations if d.get('only_register', False)]
Expand Down Expand Up @@ -248,6 +263,8 @@ def has_long_args(declaration):
self.declarations.extend(filter(lambda x: not x.get('only_stateless', False), register_only))
self.stateless_declarations.extend(filter(lambda x: x.get('only_stateless', False), register_only))

self.process_docstrings()

all_declarations = declarations + new_declarations
return all_declarations

Expand Down Expand Up @@ -296,8 +313,9 @@ def declare_methods(self, stateless):
flags += ' | METH_KEYWORDS'
if declaration.get('override_method_flags'):
flags = declaration['override_method_flags']
entry = Template(' {"$python_name", (PyCFunction)$name, $flags, NULL},\n').substitute(
python_name=declaration['python_name'], name=declaration['name'], flags=flags
entry = Template(' {"$python_name", (PyCFunction)$name, $flags, $docstring},\n').substitute(
python_name=declaration['python_name'], name=declaration['name'], flags=flags,
docstring=declaration.get('docstring_var', 'NULL')
)
if 'defined_if' in declaration:
entry = self.preprocessor_guard(entry, declaration['defined_if'])
Expand Down Expand Up @@ -332,3 +350,17 @@ def process_option_code_template(self, template, option):
new_args.append(self.ALLOCATE_TYPE[arg['type']].substitute(name=arg['name']))
template = new_args + template
return template

def generate_docstrings_cpp(self):
template = Template('char* $name = "$content";')
return '\n\n'.join(
template.substitute(name=decl['docstring_var'], content=decl['docstring_content'])
for decl in chain(self.declarations, self.stateless_declarations)
if 'docstring_var' in decl)

def generate_docstrings_h(self):
template = Template('extern char* $name;')
return '\n\n'.join(
template.substitute(name=decl['docstring_var'])
for decl in chain(self.declarations, self.stateless_declarations)
if 'docstring_var' in decl)
2 changes: 2 additions & 0 deletions torch/csrc/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "cudnn/Module.h"
#endif

#include "TensorDocstrings.h"

#define WITH_NUMPY_IMPORT_ARRAY
#include "THP.h"

Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@
#include "THP.h"
#include "copy_utils.h"

#include "TensorDocstrings.h"

#include "generic/Tensor.cpp"
#include <TH/THGenerateAllTypes.h>

0 comments on commit 28f0cf6

Please sign in to comment.