-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdtypes.py
114 lines (89 loc) · 3.13 KB
/
dtypes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import annotations
import ctypes
from dataclasses import dataclass
from .defines import CType
# Keep a handle on the builtin ``bool`` type: this module later rebinds the
# name ``bool`` to a ``Dtype`` instance, and scalar type checks still need
# the real builtin.
_python_bool = bool


@dataclass(frozen=True)
class Dtype:
    """Immutable description of one ArrayFire element type.

    Being a frozen dataclass, instances compare by value (all five fields)
    and are hashable.
    """

    # Array-API style name such as "float32"; this is also what str() returns.
    name: str
    # Single-character type code such as "f".
    typecode: str
    # ctypes type describing a single element.
    c_type: CType
    # C-style spelling of the element type, e.g. "unsigned int".
    typename: str
    # Numeric dtype identifier used by the C API. Internal use only.
    c_api_value: int

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return "arrayfire.{}(typecode<{}>)".format(self.name, self.typecode)
# Specification required
#
# Each element type is created once below under its array-API style name and
# then re-bound under a short ArrayFire-style alias; both names refer to the
# very same ``Dtype`` object.
#
# NOTE(review): in the ``array``/``struct`` modules the "l"/"L" typecodes mean
# C ``long``, which is 32 bits on Windows, whereas the matching ``c_type`` here
# is ``c_longlong``. Confirm ``typecode`` is never used to size host buffers,
# or consider "q"/"Q".
#
# ctypes has no half-precision float type; ``float16`` uses ``c_uint16`` as its
# ``c_type``, presumably as raw 16-bit storage.
#
# Complex dtypes: the Python identifiers follow ArrayFire naming by component
# width (c32 = two c_float, c64 = two c_double), while ``name`` follows
# NumPy-style naming by total width ("complex64" / "complex128").
int16 = Dtype("int16", "h", ctypes.c_short, "short int", 10)
int32 = Dtype("int32", "i", ctypes.c_int, "int", 5)
int64 = Dtype("int64", "l", ctypes.c_longlong, "long int", 8)
uint8 = Dtype("uint8", "B", ctypes.c_ubyte, "unsigned_char", 7)
uint16 = Dtype("uint16", "H", ctypes.c_ushort, "unsigned short int", 11)
uint32 = Dtype("uint32", "I", ctypes.c_uint, "unsigned int", 6)
uint64 = Dtype("uint64", "L", ctypes.c_ulonglong, "unsigned long int", 9)
float16 = Dtype("float16", "e", ctypes.c_uint16, "half", 12)
float32 = Dtype("float32", "f", ctypes.c_float, "float", 0)
float64 = Dtype("float64", "d", ctypes.c_double, "double", 2)
complex32 = Dtype("complex64", "F", ctypes.c_float * 2, "float complex", 1)  # type: ignore[arg-type]
complex64 = Dtype("complex128", "D", ctypes.c_double * 2, "double complex", 3)  # type: ignore[arg-type]
# Intentionally shadows the builtin ``bool`` inside this module.
bool = Dtype("bool", "b", ctypes.c_bool, "bool", 4)

# Short ArrayFire-style aliases (same objects as above).
s16 = int16
s32 = int32
s64 = int64
u8 = uint8
u16 = uint16
u32 = uint32
u64 = uint64
f16 = float16
f32 = float32
f64 = float64
c32 = complex32
c64 = complex64
b8 = bool
# All supported dtypes: first under their long names, then under their short
# aliases. The aliases are the same objects as the long names, so every dtype
# appears twice; order and contents are significant for callers that iterate
# or index this tuple.
supported_dtypes = (
    int16, int32, int64,
    uint8, uint16, uint32, uint64,
    float16, float32, float64,
    complex64, complex32,
    bool,
) + (
    s16, s32, s64,
    u8, u16, u32, u64,
    f16, f32, f64,
    c32, c64,
    b8,
)
def to_str(c_str: ctypes.c_char_p | ctypes.Array[ctypes.c_char]) -> str:
    """Decode a ctypes char pointer or char buffer into a Python string.

    The bytes exposed by ``.value`` are decoded as UTF-8. A NULL ``c_char_p``
    (whose ``.value`` is ``None``) raises ``AttributeError``.
    """
    raw_bytes = c_str.value
    return raw_bytes.decode("utf-8")  # type: ignore[union-attr]
def implicit_dtype(number: int | float | _python_bool | complex, array_dtype: Dtype) -> Dtype:
    """Return the dtype a Python scalar should take when combined with an array.

    The scalar's natural dtype is chosen from its Python type:
    ``bool`` -> ``bool``, ``int`` -> ``int64``, ``float`` -> ``float64`` and
    ``complex`` -> ``complex64`` (double-precision complex, alias ``c64``).
    When the array is single precision (``float32`` or ``complex32``), a
    double-precision scalar dtype is demoted to its single-precision
    counterpart so that the scalar does not upcast the array.

    Args:
        number: The Python scalar operand.
        array_dtype: Dtype of the array operand; must be in ``supported_dtypes``.

    Returns:
        The dtype to use for ``number``.

    Raises:
        TypeError: If ``number`` is not a bool, int, float or complex.
        ValueError: If ``array_dtype`` is not a supported dtype.
    """
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(number, _python_bool):
        number_dtype = bool
    elif isinstance(number, int):
        number_dtype = int64
    elif isinstance(number, float):
        # Python floats are C doubles; mapping them to float32 here would
        # silently round scalars combined with float64 arrays.
        number_dtype = float64
    elif isinstance(number, complex):
        # Python complex numbers hold two doubles (c64 in this module's naming).
        number_dtype = complex64
    else:
        raise TypeError(f"{type(number)} is not supported and can not be converted to af.Dtype.")

    if array_dtype not in supported_dtypes:
        raise ValueError(f"{array_dtype} is not in supported dtypes.")

    # Keep single-precision arrays single precision.
    if array_dtype in (float32, complex32):
        if number_dtype == float64:
            return float32
        if number_dtype == complex64:
            return complex32

    return number_dtype
def c_api_value_to_dtype(value: int) -> Dtype:
    """Return the supported dtype whose ``c_api_value`` equals ``value``.

    Raises:
        TypeError: If no supported dtype carries that C API value.
    """
    found = next((candidate for candidate in supported_dtypes if value == candidate.c_api_value), None)
    if found is None:
        raise TypeError("There is no supported dtype that matches passed dtype C API value.")
    return found
def str_to_dtype(value: str) -> Dtype:
    """Look up a supported dtype by its typecode, typename or name.

    For example "f", "float" and "float32" all resolve to ``float32``.

    Args:
        value: A dtype typecode, C-style typename or array-API style name.

    Returns:
        The first supported dtype with a matching typecode, typename or name.

    Raises:
        TypeError: If no supported dtype matches ``value`` (kept as TypeError
            for backward compatibility with existing callers).
    """
    for dtype in supported_dtypes:
        if value in (dtype.typecode, dtype.typename, dtype.name):
            return dtype
    # The message previously mentioned only "typecode", although typenames and
    # names are accepted as well; also report the value that failed to match.
    raise TypeError(f"There is no supported dtype that matches passed dtype typecode, typename or name: {value!r}.")