import torch.fx


class BackwardState:
    """
    BackwardState is used to pass Python hooks from the forwards pass
    into the backwards pass in Dynamo+Compiled Autograd.

    It is created by TorchDynamo and has special handling there.
    Dynamo will pass an empty BackwardState to the forwards, then populate
    members on it (via setattr) only after the forwards graph is finished.
    Later on, in Compiled Autograd, we will inline and add the needed guards
    on the BackwardState.

    BackwardState is identified and has special handling in AOTAutograd.
    During AOTAutograd:
        1) BackwardState is an input to the forwards graph
        2) It must only be used in the backwards
        3) It will be empty in the forwards
        4) In the forwards we add a wrapper to save it
        5) In the backwards it becomes an input
        6) There can only be one per graph

    BackwardState requires CompiledAutograd.
    """

    proxy: torch.fx.Proxy
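

# Illustrative sketch (hypothetical, not part of PyTorch): a pure-Python model
# of the BackwardState lifecycle described in the docstring above. The names
# _HookState, _forward, _backward and _run_example are made up for this
# example. "Graphs" are plain Python functions and gradients are floats, so
# this only loosely mirrors the numbered steps. It is not how Dynamo or
# AOTAutograd actually implement them.


class _HookState:
    """Hypothetical stand-in for BackwardState: an initially empty attribute bag."""


def _forward(x, state):
    # Step 1: the state is an input to the forward.
    # Step 3: it must still be empty here, so the forward cannot depend on it.
    if vars(state):
        raise ValueError("state must be empty during the forward")
    out = x * 2.0
    # Step 4: the forward only saves the state (alongside its input) for later.
    saved = (x, state)
    return out, saved


def _backward(grad_out, saved):
    # Step 5: in the backward, the saved state becomes an input.
    _, state = saved
    grad_x = grad_out * 2.0  # d(2x)/dx = 2
    # Step 2: the hooks stored on the state are used only here. As with
    # PyTorch tensor hooks, a hook may return a replacement gradient;
    # returning None leaves the gradient unchanged.
    for hook in getattr(state, "hooks", []):
        result = hook(grad_x)
        if result is not None:
            grad_x = result
    return grad_x


def _run_example():
    state = _HookState()  # created empty; one per "graph" (step 6)
    _, saved = _forward(3.0, state)
    # Populated via setattr only after the forward "graph" has finished.
    setattr(state, "hooks", [lambda g: g * 10.0])
    return _backward(1.0, saved)


# Usage: _run_example() returns 20.0, because the hook scales the gradient
# 2.0 by 10. The forward never reads the state, so the hooks can be attached
# after it has run and consumed only in the backward.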