forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[functorch] Fix torch.cat batching rule (pytorch#86932)
The bug was discovered in pytorch#86842. torch.cat has an edge case where it ignores all tensors of shape [0]. So if any of the BatchedTensors have logical shape [0] but physical shape [B, 0], then we coerce them to shape [0] by slicing them. Why don't we just ignore those Tensors? We need to propagate requires_grad-ness somehow (e.g. if the BatchedTensor wraps a Tensor of shape [B, 0] that requires grad, then the output must require grad). Test Plan: - new tests Pull Request resolved: pytorch#86932 Approved by: https://github.com/Chillee
- Loading branch information
1 parent
c16b7b4
commit b805e1a
Showing
4 changed files
with
75 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters