Skip to content

Commit

Permalink
Add tensor.repeat docs. Remove legacy tensor repeat function. (pytorch#5666)
Browse files Browse the repository at this point in the history

* Add tensor.repeat docs. Remove legacy tensor repeat function.

* Fix nit
  • Loading branch information
zou3519 authored and soumith committed Mar 10, 2018
1 parent b5ee5e5 commit 439aae7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 51 deletions.
29 changes: 29 additions & 0 deletions torch/_tensor_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,35 @@ def callable(a, b) -> number
In-place version of :meth:`~Tensor.renorm`
""")

# Attach the documentation for Tensor.repeat. The docstring is reStructuredText
# rendered by Sphinx: section headers like ``Args:`` need a preceding blank
# line, and ``Example::`` introduces a literal block whose content must be
# blank-line separated and indented, otherwise the rendered docs collapse
# into one run-on paragraph.
add_docstr_all('repeat',
               r"""
repeat(*sizes) -> Tensor

Repeats this tensor along the specified dimensions.

Unlike :meth:`~Tensor.expand`, this function copies the tensor's data.

Args:
    sizes (torch.Size or int...): The number of times to repeat this tensor along each
        dimension

Example::

    >>> x = torch.Tensor([1, 2, 3])
    >>> x.repeat(4, 2)
     1  2  3  1  2  3
     1  2  3  1  2  3
     1  2  3  1  2  3
     1  2  3  1  2  3
    [torch.FloatTensor of size (4,6)]
    >>> x.repeat(4, 2, 1).size()
    torch.Size([4, 2, 3])
""")

add_docstr_all('resize_',
r"""
resize_(*sizes) -> Tensor
Expand Down
51 changes: 0 additions & 51 deletions torch/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,54 +259,3 @@ def _take_tensors(tensors, size_limit):
for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf


def _repeat(self, *sizes):
    r"""Repeats this tensor along the specified dimensions.

    Unlike :meth:`expand`, this function copies the tensor's data.

    Args:
        *sizes (torch.Size or int...): The number of times to repeat this
            tensor along each dimension

    Example:
        >>> x = torch.Tensor([1, 2, 3])
        >>> x.repeat(4, 2)
         1  2  3  1  2  3
         1  2  3  1  2  3
         1  2  3  1  2  3
         1  2  3  1  2  3
        [torch.FloatTensor of size 4x6]
        >>> x.repeat(4, 2, 1).size()
        torch.Size([4, 2, 3])
    """
    # If args == (torch.Size,), then we need to unpack the tuple
    if len(sizes) == 1 and isinstance(sizes[0], torch.Size):
        sizes = sizes[0]

    repeats = list(sizes)

    # repeat cannot reduce dimensionality; fewer repeat dims than tensor
    # dims is rejected (matching expand's opposite-direction constraint).
    if len(repeats) < self.dim():
        raise ValueError('Number of dimensions of repeat dims can not be '
                         'smaller than number of dimensions of tensor')

    # Add new leading dimensions to the tensor if the
    # number of target dimensions is larger than the
    # number of source dimensions.
    num_new_dimensions = len(repeats) - self.dim()
    padded_size = [1] * num_new_dimensions + list(self.size())
    # Output shape: each (possibly padded) source dim times its repeat count.
    target_size = torch.Size([a * b for a, b in zip(padded_size, repeats)])

    # View the source data at the padded (leading-1s) shape without copying:
    # new().set_(self) aliases self's storage, then expand broadcasts the
    # size-1 leading dims.
    xtensor = self.new().set_(self)
    xtensor = xtensor.expand(padded_size)

    # Allocate the output, then build ``urtensor`` over it. Unfolding every
    # dimension i into windows of size padded_size[i] with matching step
    # reshapes the output view into a grid of source-shaped tiles, so one
    # broadcasted copy_ below fills all tiles at once.
    # NOTE(review): this relies on legacy ``new(tensor)`` aliasing storage
    # (and unfold returning views) so writes land in ``result`` — confirm
    # against the old torch.Tensor constructor semantics.
    result = self.new()
    result.resize_(target_size)
    urtensor = result.new(result)
    for i in range(xtensor.dim()):
        urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i))

    # Broadcast the source across the tile dimensions and copy into place.
    urtensor.copy_(xtensor.expand_as(urtensor))

    return result

0 comments on commit 439aae7

Please sign in to comment.