Skip to content

Commit

Permalink
Fix type annotations for torch.sparse, enable in CI (pytorch#43108)
Browse files Browse the repository at this point in the history
Summary:
Closes pytorchgh-42982

Pull Request resolved: pytorch#43108

Reviewed By: malfet

Differential Revision: D23167560

Pulled By: ezyang

fbshipit-source-id: 0d660ca686ada2347bf440c6349551d1539f99ef
  • Loading branch information
rgommers authored and facebook-github-bot committed Aug 17, 2020
1 parent 6db0b87 commit 864f0cf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,6 @@ ignore_errors = True
[mypy-torch.jit.*]
ignore_errors = True

[mypy-torch.sparse]
ignore_errors = True

[mypy-torch.tensor]
ignore_errors = True

Expand Down
19 changes: 10 additions & 9 deletions torch/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# The Tensor classes are added to this module by python_tensor.cpp
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Union

import torch
from torch import Tensor

# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch import dtype as DType
from torch.types import _dtype as DType
DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
# TODO: replace the above with
# from torch.types import _dtype as DType
DimOrDims = Optional[Tuple[int]]


__all__ = [
Expand All @@ -23,8 +24,8 @@
]


def addmm(mat, mat1, mat2, beta=1, alpha=1):
# type: (Tensor, Tensor, Tensor, float, float) -> Tensor
def addmm(mat: Tensor, mat1: Tensor, mat2: Tensor,
beta: float = 1., alpha: float = 1.) -> Tensor:
r"""
This function does exact same thing as :func:`torch.addmm` in the forward,
except that it supports backward for sparse matrix :attr:`mat1`. :attr:`mat1`
Expand All @@ -41,7 +42,7 @@ def addmm(mat, mat1, mat2, beta=1, alpha=1):
return torch._sparse_addmm(mat, mat1, mat2, beta=beta, alpha=alpha)


def mm(mat1, mat2):
def mm(mat1: Tensor, mat2: Tensor) -> Tensor:
r"""
Performs a matrix multiplication of the sparse matrix :attr:`mat1`
and dense matrix :attr:`mat2`. Similar to :func:`torch.mm`, If :attr:`mat1` is a
Expand Down Expand Up @@ -83,8 +84,8 @@ def mm(mat1, mat2):
return torch._sparse_mm(mat1, mat2)


def sum(input, dim=None, dtype=None):
# type: (Tensor, Optional[Tuple[int]], Optional[int]) -> Tensor
def sum(input: Tensor, dim: DimOrDims = None,
dtype: Optional[DType] = None) -> Tensor:
r"""
Returns the sum of each row of SparseTensor :attr:`input` in the given
dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
Expand Down

0 comments on commit 864f0cf

Please sign in to comment.