Skip to content

Commit

Permalink
Begin to document TensorBase methods (pytorch#466)
Browse files Browse the repository at this point in the history
  • Loading branch information
colesbury authored and soumith committed Jan 18, 2017
1 parent 90fe6dd commit a09f653
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 30 deletions.
2 changes: 0 additions & 2 deletions docs/source/tensors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ view of a storage and defines numeric operations on it.
.. autoattribute:: is_cuda
:annotation:
.. automethod:: is_pinned
.. automethod:: is_same_size
.. automethod:: is_set_to
.. automethod:: is_signed
.. automethod:: kthvalue
Expand Down Expand Up @@ -310,4 +309,3 @@ view of a storage and defines numeric operations on it.
.. automethod:: view
.. automethod:: view_as
.. automethod:: zero_
.. automethod:: zeros_
164 changes: 150 additions & 14 deletions torch/_tensor_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@

add_docstr(torch._C.FloatTensorBase.apply_,
"""
apply_(callable) -> Tensor
Applies the function :attr:`callable` to each element in the tensor, replacing
each element with the value returned by :attr:`callable`.
""")

add_docstr(torch._C.FloatTensorBase.asin,
Expand Down Expand Up @@ -212,6 +216,13 @@

add_docstr(torch._C.FloatTensorBase.cauchy_,
"""
cauchy_(generator=None, median=0, sigma=1) -> Tensor
Fills the tensor with numbers drawn from the Cauchy distribution:
.. math::
P(x) = \dfrac{1}{\pi} \dfrac{\sigma}{(x - median)^2 + \sigma^2}
""")

add_docstr(torch._C.FloatTensorBase.ceil,
Expand Down Expand Up @@ -244,14 +255,34 @@

add_docstr(torch._C.FloatTensorBase.clone,
"""
clone() -> Tensor
Returns a copy of the tensor. The copy has the same size and data type as the
original tensor.
""")

add_docstr(torch._C.FloatTensorBase.contiguous,
"""
contiguous() -> Tensor
Returns a contiguous Tensor containing the same data as this tensor. If this
tensor is contiguous, this function returns the original tensor.
""")

add_docstr(torch._C.FloatTensorBase.copy_,
"""
copy_(src, async=False) -> Tensor
Copies the elements from :attr:`src` into this tensor and returns this tensor.
The source tensor should have the same number of elements as this tensor. It
may be of a different data type or reside on a different device.
Args:
src (Tensor): Source tensor to copy
async (bool): If True and this copy is between CPU and GPU, then the copy
may occur asynchronously with respect to the host. For other
copies, this argument has no effect.
""")

add_docstr(torch._C.FloatTensorBase.cos,
Expand Down Expand Up @@ -305,6 +336,9 @@

add_docstr(torch._C.FloatTensorBase.data_ptr,
"""
data_ptr() -> int
Returns the address of the first element of this tensor.
""")

add_docstr(torch._C.FloatTensorBase.diag,
Expand All @@ -316,6 +350,9 @@

add_docstr(torch._C.FloatTensorBase.dim,
"""
dim() -> int
Returns the number of dimensions of this tensor.
""")

add_docstr(torch._C.FloatTensorBase.dist,
Expand Down Expand Up @@ -355,6 +392,15 @@

add_docstr(torch._C.FloatTensorBase.element_size,
"""
element_size() -> int
Returns the size in bytes of an individual element.
Example:
>>> torch.FloatTensor().element_size()
4
>>> torch.ByteTensor().element_size()
1
""")

add_docstr(torch._C.FloatTensorBase.eq,
Expand Down Expand Up @@ -394,10 +440,20 @@

add_docstr(torch._C.FloatTensorBase.exponential_,
"""
exponential_(generator=None, lambd=1) -> Tensor
Fills this tensor with elements drawn from the exponential distribution:
.. math::
P(x) = \lambda e^{-\lambda x}
""")

add_docstr(torch._C.FloatTensorBase.fill_,
"""
fill_(value) -> Tensor
Fills this tensor with the specified value.
""")

add_docstr(torch._C.FloatTensorBase.floor,
Expand Down Expand Up @@ -472,26 +528,33 @@

add_docstr(torch._C.FloatTensorBase.geometric_,
"""
geometric_(generator=None, p) -> Tensor
Fills this tensor with elements drawn from the geometric distribution:
.. math::
P(X=k) = (1 - p)^{k - 1} p
""")

add_docstr(torch._C.FloatTensorBase.geqrf,
"""
geqrf() -> (Tensor, Tensor)
TODO: fix signature
See :func:`torch.geqrf`
""")

add_docstr(torch._C.FloatTensorBase.ger,
"""
ger(vec2) -> Tensor
See :func:`torch.ger`
""")

add_docstr(torch._C.FloatTensorBase.gesv,
"""
gesv(A) -> Tensor, Tensor
See :func:`torch.gesv`
""")
Expand Down Expand Up @@ -519,18 +582,86 @@

add_docstr(torch._C.FloatTensorBase.index,
"""
index(m) -> Tensor
Selects elements from this tensor using a binary mask or along a given
dimension. The expression ``tensor.index(m)`` is equivalent to ``tensor[m]``.
Args:
m (int or ByteTensor or slice): The dimension or mask used to select elements
""")

add_docstr(torch._C.FloatTensorBase.index_add_,
"""
index_add_(dim, index, tensor) -> Tensor
Accumulate the elements of tensor into the original tensor by adding to the
indices in the order given in index. The shape of tensor must exactly match the
elements indexed or an error will be raised.
Args:
dim (int): Dimension along which to index
index (LongTensor): Indices to select from tensor
tensor (Tensor): Tensor containing values to add
Example:
>>> x = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
>>> t = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = torch.LongTensor([0, 2, 1])
>>> x.index_add_(0, index, t)
>>> x
2 3 4
8 9 10
5 6 7
[torch.FloatTensor of size 3x3]
""")

add_docstr(torch._C.FloatTensorBase.index_copy_,
"""
index_copy_(dim, index, tensor) -> Tensor
Copies the elements of tensor into the original tensor by selecting the
indices in the order given in index. The shape of tensor must exactly match the
elements indexed or an error will be raised.
Args:
dim (int): Dimension along which to index
index (LongTensor): Indices to select from tensor
tensor (Tensor): Tensor containing values to copy
Example:
>>> x = torch.Tensor(3, 3)
>>> t = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = torch.LongTensor([0, 2, 1])
>>> x.index_copy_(0, index, t)
>>> x
1 2 3
7 8 9
4 5 6
[torch.FloatTensor of size 3x3]
""")

add_docstr(torch._C.FloatTensorBase.index_fill_,
"""
index_fill_(dim, index, tensor) -> Tensor
Fills the elements of the original tensor with value :attr:`val` by selecting
the indices in the order given in index.
Args:
dim (int): Dimension along which to index
index (LongTensor): Indices
val (float): Value to fill
Example:
>>> x = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = torch.LongTensor([0, 2])
>>> x.index_fill_(1, index, -1)
>>> x
-1 2 -1
-1 5 -1
-1 8 -1
[torch.FloatTensor of size 3x3]
""")

add_docstr(torch._C.FloatTensorBase.index_select,
Expand All @@ -542,17 +673,16 @@

add_docstr(torch._C.FloatTensorBase.inverse,
"""
inverse() -> Tensor
See :func:`torch.inverse`
""")

add_docstr(torch._C.FloatTensorBase.is_contiguous,
"""
""")
is_contiguous() -> bool
add_docstr(torch._C.FloatTensorBase.is_same_size,
"""
Returns True if this tensor is contiguous in memory in C order.
""")

add_docstr(torch._C.FloatTensorBase.is_set_to,
Expand Down Expand Up @@ -1111,14 +1241,14 @@

add_docstr(torch._C.FloatTensorBase.svd,
"""
svd(some=True) -> (Tensor, Tensor, Tensor)
See :func:`torch.svd`
""")

add_docstr(torch._C.FloatTensorBase.symeig,
"""
symeig(eigenvectors=False, upper=True) -> (Tensor, Tensor)
See :func:`torch.symeig`
""")
Expand Down Expand Up @@ -1223,7 +1353,7 @@

add_docstr(torch._C.FloatTensorBase.trtrs,
"""
trtrs(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
See :func:`torch.trtrs`
""")
Expand All @@ -1244,13 +1374,20 @@

add_docstr(torch._C.FloatTensorBase.unfold,
"""
unfold(dim, size, step) -> Tensor
See :func:`torch.unfold`
""")

add_docstr(torch._C.FloatTensorBase.uniform_,
"""
uniform_(from=0, to=1) -> Tensor
Fills this tensor with numbers sampled from the uniform distribution:
.. math:
P(x) = \dfrac{1}{to - from}
""")

add_docstr(torch._C.FloatTensorBase.var,
Expand All @@ -1262,8 +1399,7 @@

add_docstr(torch._C.FloatTensorBase.zero_,
"""
""")
zero_()
add_docstr(torch._C.FloatTensorBase.zeros_,
"""
Fills this tensor with zeros.
""")
10 changes: 0 additions & 10 deletions torch/csrc/generic/methods/Tensor.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,6 @@ PyObject * THPTensor_(setIndex)(THPTensor *self, PyObject *args)
long_args: True
]]

[[
name: zeros_
cname: zeros
return: self
arguments:
- THTensor* self
- arg: THSize* size
long_args: True
]]

[[
name: ones
only_stateless: True
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/generic/methods/TensorRandom.cwrap
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ static void THTensor_(normal_means_stddevs)(THTensor *self, THGenerator *gen, TH
- THTensor* self
- arg: THGenerator* generator
default: THPDefaultGenerator->cdata
- arg: real location
- arg: real median
default: 0
- arg: real scale
- arg: real sigma
default: 1
]]

Expand Down Expand Up @@ -335,9 +335,9 @@ static void THTensor_(normal_means_stddevs)(THCState *_, THTensor *self, THTenso
return: self
arguments:
- THTensor* self
- arg: double location
- arg: double median
default: 0
- arg: double scale
- arg: double sigma
default: 1
]]

Expand Down

0 comments on commit a09f653

Please sign in to comment.