Commit 12d5126 1 parent 81d61cd commit 12d5126 Copy full SHA for 12d5126
File tree 1 file changed +18
-1
lines changed
1 file changed +18
-1
lines changed Original file line number Diff line number Diff line change 8
8
# This file is part of Jacinle.
9
9
# Distributed under terms of the MIT license.
10
10
11
+ import torch
11
12
import torch .nn as nn
12
13
13
- __all__ = ['Identity' ]
14
+ __all__ = ['Identity' , 'TorchApplyRecorderMixin' ]
14
15
15
16
16
17
class Identity (nn .Module ):
@@ -19,3 +20,19 @@ def forward(self, *args):
19
20
return args [0 ]
20
21
return args
21
22
23
+
24
+ class TorchApplyRecorderMixin (nn .Module ):
25
+ def __init__ (self ):
26
+ super ().__init__ ()
27
+ self ._apply_recorder_indicator = nn .Parameter (
28
+ torch .tensor (0 , dtype = torch .float32 , device = torch .device ('cpu' ))
29
+ )
30
+ self ._apply_recorder_indicator .requires_grad = False
31
+
32
+ @property
33
+ def dtype (self ):
34
+ return self ._apply_recorder_indicator .dtype
35
+
36
+ @property
37
+ def device (self ):
38
+ return self ._apply_recorder_indicator .device
You can’t perform that action at this time.
0 commit comments