-
Notifications
You must be signed in to change notification settings - Fork 681
/
Copy pathhelpers.py
66 lines (48 loc) · 1.67 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from io import BytesIO
from itertools import product
import random
from typing import Any
import torch
# Dedicated RNG with a fixed seed. get_test_dims() draws from it, so the generated
# dimensions are reproducible run-to-run and the global `random` state is untouched.
test_dims_rng = random.Random(42)

TRUE_FALSE = (True, False)
# Every ordered combination of three / two booleans, True-first
# (e.g. for pairs: (T, T), (T, F), (F, T), (F, F)).
BOOLEAN_TRIPLES = [*product(TRUE_FALSE, repeat=3)]
BOOLEAN_TUPLES = [*product(TRUE_FALSE, repeat=2)]
def torch_save_to_buffer(obj):
    """Serialize *obj* with ``torch.save`` into a new in-memory buffer.

    Returns:
        A ``BytesIO`` holding the serialized bytes, rewound to position 0 so it
        can be handed directly to a reader such as ``torch.load``.
    """
    stream = BytesIO()
    torch.save(obj, stream)
    # torch.save leaves the cursor at the end; rewind so readers start at byte 0.
    stream.seek(0)
    return stream
def torch_load_from_buffer(buffer):
    """Deserialize an object that ``torch.save`` wrote into *buffer*.

    The buffer is rewound before reading and rewound again afterwards, so the same
    buffer can be loaded more than once.

    NOTE: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary objects.
    Only call this on trusted, locally produced data.
    """
    buffer.seek(0)  # read from the start regardless of the current position
    loaded = torch.load(buffer, weights_only=False)
    buffer.seek(0)  # leave the buffer ready for another load
    return loaded
def get_test_dims(min: int, max: int, *, n: int) -> list[int]:
    """Draw *n* random integers, each in the inclusive range [min, max].

    Values come from the module-level ``test_dims_rng`` (seeded with 42). The
    sequence is therefore reproducible as long as calls happen in the same order.
    ``min``/``max`` shadow the builtins but are kept because callers may pass
    them by keyword.
    """
    dims: list[int] = []
    for _ in range(n):
        dims.append(test_dims_rng.randint(min, max))
    return dims
def format_with_label(label: str, value: Any) -> str:
    """Render *value* as ``"<label>=<text>"`` with a compact text form.

    - a single bool becomes ``"T"`` or ``"F"``;
    - a list/tuple whose items are all bools becomes a run of ``T``/``F`` letters
      (an empty list/tuple therefore renders as an empty string);
    - a ``torch.dtype`` is shortened via ``describe_dtype``;
    - anything else uses ``str(value)``.
    """

    def _flag(bit: bool) -> str:
        return "T" if bit else "F"

    # Check order matters: the scalar-bool case must be tested before the others.
    if isinstance(value, bool):
        text = _flag(value)
    elif isinstance(value, (list, tuple)) and all(isinstance(item, bool) for item in value):
        text = "".join(map(_flag, value))
    elif isinstance(value, torch.dtype):
        text = describe_dtype(value)
    else:
        text = str(value)
    return f"{label}={text}"
def id_formatter(label: str):
    """
    Build a one-argument callable that renders its argument as ``"<label>=<value>"``.

    The returned function delegates to ``format_with_label`` with *label* bound.
    It is presumably meant for pytest's ``ids=`` parameter (TODO confirm with callers).
    """

    def format_value(value):
        return format_with_label(label, value)

    return format_value
# Short aliases for commonly used dtypes. Dtypes not listed here fall back to
# their plain torch name in describe_dtype().
DTYPE_NAMES = {
    torch.bfloat16: "bf16",
    torch.bool: "bool",
    torch.float16: "fp16",
    torch.float32: "fp32",
    torch.float64: "fp64",
    torch.int32: "int32",
    torch.int64: "int64",
    torch.int8: "int8",
}


def describe_dtype(dtype: torch.dtype) -> str:
    """Return a short name for *dtype*.

    Uses the alias from ``DTYPE_NAMES`` when one is present (and non-empty).
    Otherwise it returns the part of ``str(dtype)`` after the last ``"."``, e.g.
    ``torch.uint8`` -> ``"uint8"``.
    """
    alias = DTYPE_NAMES.get(dtype)
    if alias:
        return alias
    return str(dtype).split(".")[-1]