tools.py
forked from traveller59/second.pytorch
import functools
import inspect
import sys
from collections import OrderedDict
import numba
import numpy as np
import torch


def get_pos_to_kw_map(func):
    # Map positional index -> parameter name for func's positional-or-keyword
    # parameters (index 0 is `self` when func is a class's __init__).
    pos_to_kw = {}
    fsig = inspect.signature(func)
    pos = 0
    for name, info in fsig.parameters.items():
        if info.kind is info.POSITIONAL_OR_KEYWORD:
            pos_to_kw[pos] = name
        pos += 1
    return pos_to_kw


def get_kw_to_default_map(func):
    # Map parameter name -> default value for func's positional-or-keyword
    # parameters that declare a default.
    kw_to_default = {}
    fsig = inspect.signature(func)
    for name, info in fsig.parameters.items():
        if info.kind is info.POSITIONAL_OR_KEYWORD:
            if info.default is not info.empty:
                kw_to_default[name] = info.default
    return kw_to_default
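
# Example (a sketch using a hypothetical function `f`, not defined in this file):
#
#   def f(a, b, c=1):
#       ...
#
#   get_pos_to_kw_map(f)      # -> {0: 'a', 1: 'b', 2: 'c'}
#   get_kw_to_default_map(f)  # -> {'c': 1}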


def change_default_args(**kwargs):
    # Class decorator factory: returns a wrapper that subclasses `layer_class`
    # and injects `kwargs` as new default arguments. An injected value is used
    # only when the caller supplies the argument neither as a keyword nor
    # positionally.
    def layer_wrapper(layer_class):
        class DefaultArgLayer(layer_class):
            def __init__(self, *args, **kw):
                pos_to_kw = get_pos_to_kw_map(layer_class.__init__)
                kw_to_pos = {kw: pos for pos, kw in pos_to_kw.items()}
                for key, val in kwargs.items():
                    # kw_to_pos[key] > len(args): the argument was not already
                    # covered by the positional args (position 0 is `self`).
                    if key not in kw and kw_to_pos[key] > len(args):
                        kw[key] = val
                super().__init__(*args, **kw)

        return DefaultArgLayer

    return layer_wrapper
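
# Typical usage (an illustrative sketch; `nn` stands for torch.nn, which this
# module does not import):
#
#   Conv2d = change_default_args(bias=False)(nn.Conv2d)
#   BatchNorm2d = change_default_args(eps=1e-3, momentum=0.01)(nn.BatchNorm2d)
#
# Calls such as Conv2d(3, 16, 3) then default to bias=False, while an explicit
# Conv2d(3, 16, 3, bias=True) still takes priority.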


def torch_to_np_dtype(ttype):
    # Translate a torch dtype into the corresponding numpy dtype.
    type_map = {
        torch.float16: np.dtype(np.float16),
        torch.float32: np.dtype(np.float32),
        torch.float64: np.dtype(np.float64),
        torch.int32: np.dtype(np.int32),
        torch.int64: np.dtype(np.int64),
        torch.uint8: np.dtype(np.uint8),
    }
    return type_map[ttype]
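

if __name__ == "__main__":
    # Minimal usage sketch exercising the helpers above (illustrative only;
    # assumes torch.nn is available).
    import torch.nn as nn

    # New default: convolutions without bias, unless the caller says otherwise.
    Conv2d = change_default_args(bias=False)(nn.Conv2d)
    assert Conv2d(3, 16, kernel_size=3).bias is None
    assert Conv2d(3, 16, kernel_size=3, bias=True).bias is not None

    # dtype translation.
    assert torch_to_np_dtype(torch.float32) == np.dtype(np.float32)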