Skip to content

Commit 12d5126

Browse files
committed
[master] add TorchApplyRecorderMixin
1 parent 81d61cd commit 12d5126

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

jactorch/nn/simple.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
# This file is part of Jacinle.
99
# Distributed under terms of the MIT license.
1010

11+
import torch
1112
import torch.nn as nn
1213

13-
__all__ = ['Identity']
14+
__all__ = ['Identity', 'TorchApplyRecorderMixin']
1415

1516

1617
class Identity(nn.Module):
@@ -19,3 +20,19 @@ def forward(self, *args):
1920
return args[0]
2021
return args
2122

23+
24+
class TorchApplyRecorderMixin(nn.Module):
25+
def __init__(self):
26+
super().__init__()
27+
self._apply_recorder_indicator = nn.Parameter(
28+
torch.tensor(0, dtype=torch.float32, device=torch.device('cpu'))
29+
)
30+
self._apply_recorder_indicator.requires_grad = False
31+
32+
@property
33+
def dtype(self):
34+
return self._apply_recorder_indicator.dtype
35+
36+
@property
37+
def device(self):
38+
return self._apply_recorder_indicator.device

0 commit comments

Comments
 (0)