Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (pytorch#127313)

### Before this PR

`torch.utils.swap_tensors(a, b)` required the `use_count` of both `a` and `b` to be 1.

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by the
# AccumulateGrad node, which is not cleaned up after backward:
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here would pass:
torch.utils.swap_tensors(a, b)
```

### After this PR

`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be 1, or 2 if the second reference is held by `AccumulateGrad`. In the latter case, a pre-hook is registered on the `AccumulateGrad` node so that it fails if it is ever called (i.e. if the user attempts to backward through the graph).

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is OK.
torch.utils.swap_tensors(a, b)
# If we ever backward to the AccumulateGrad node, it will error that it
# was poisoned by swap_tensors.
```

### Application to `nn.Module`

This issue is especially pertinent in the context of `nn.Module`, where parameters have `AccumulateGrad` nodes initialized after forward. Specifically, this is intended to address pytorch#126814 (comment). Previously, the code below would fail at `m.cpu()`; with this change it succeeds, and an error is raised only if the user later attempts to backward through the poisoned `AccumulateGrad` node.

```python
import torch
import torch.nn as nn

m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```

Pull Request resolved: pytorch#127313
Approved by: https://github.com/soulitzer
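The "poisoning" mechanism above can be illustrated without PyTorch. The sketch below is a hypothetical, torch-free stand-in (the `AccumulateGradStub` class and `poison` helper are invented for illustration): a node that runs registered pre-hooks before doing its work, and a poisoning step that installs a pre-hook which always raises, so any later traversal of the node errors out.

```python
class AccumulateGradStub:
    """Stand-in for an autograd node that supports pre-hooks (not PyTorch internals)."""

    def __init__(self):
        self._pre_hooks = []

    def register_prehook(self, hook):
        self._pre_hooks.append(hook)

    def __call__(self, grad):
        # Pre-hooks run before the node's own work; a poisoning hook raises here.
        for hook in self._pre_hooks:
            hook(grad)
        return grad  # normally: accumulate grad into the leaf tensor


def poison(node):
    """Mimics what swap_tensors does when the extra reference is held by
    AccumulateGrad: the swap is allowed, but any later backward through
    the node raises."""

    def raise_poisoned(grad):
        raise RuntimeError(
            "Trying to execute AccumulateGrad node that was poisoned by "
            "torch.utils.swap_tensors."
        )

    node.register_prehook(raise_poisoned)


node = AccumulateGradStub()
node(1.0)  # fine before the swap
poison(node)
try:
    node(1.0)  # any backward through the node now fails
except RuntimeError as e:
    print("poisoned:", e)
```

The real implementation registers the hook on the tensor's actual `AccumulateGrad` grad_fn; the point is only that poisoning is a pre-hook that unconditionally raises, not a change to the node's accumulation logic.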
1 parent d44ab8b · commit cd06ae0 · 7 changed files with 70 additions and 23 deletions.