tensor.go
package cudnn

// #include <cudnn.h>
import "C"

import (
	"runtime"
)
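
// TensorDescriptor wraps a cudnnTensorDescriptor_t and caches the format,
// data type, shape and strides it was created with.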
type TensorDescriptor struct {
	internal C.cudnnTensorDescriptor_t // pointer to the underlying cuDNN struct

	// internal data cached for fast access
	format   TensorFormat
	dataType DataType
	shape    []int // NCHW format for 4-tensors
	strides  []int
}
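
// NewTensorDescriptor creates and initializes a cuDNN tensor descriptor.
// If strides is empty, cuDNN derives the strides from the given format.
// A finalizer is attached so the underlying descriptor is destroyed when the
// Go value is garbage collected.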
func NewTensorDescriptor(format TensorFormat, dt DataType, shape, strides []int) (*TensorDescriptor, error) {
	var internal C.cudnnTensorDescriptor_t
	if err := result(C.cudnnCreateTensorDescriptor(&internal)); err != nil {
		return nil, err
	}

	retVal := &TensorDescriptor{
		internal: internal,
		format:   format,
		dataType: dt,
		shape:    shape,
		strides:  strides,
	}
	runtime.SetFinalizer(retVal, destroyTensor)

	if err := retVal.set(internal); err != nil {
		return nil, err
	}
	return retVal, nil
}
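
// set configures the underlying cuDNN descriptor from the cached shape and
// strides: the 4d-specific calls are used for rank-4 tensors, the Nd calls
// otherwise.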
func (t *TensorDescriptor) set(internal C.cudnnTensorDescriptor_t) error {
	switch len(t.shape) {
	case 4:
		n, c, h, w := t.shape[0], t.shape[1], t.shape[2], t.shape[3]
		if len(t.strides) == 4 {
			// use explicit strides
			NStrides, CStrides, HStrides, WStrides := t.strides[0], t.strides[1], t.strides[2], t.strides[3]
			res := C.cudnnSetTensor4dDescriptorEx(internal, t.dataType.C(),
				C.int(n), C.int(c), C.int(h), C.int(w),
				C.int(NStrides), C.int(CStrides), C.int(HStrides), C.int(WStrides),
			)
			return result(res)
		}
		// otherwise the strides will be calculated by cuDNN from the format
		res := C.cudnnSetTensor4dDescriptor(internal, t.format.C(), t.dataType.C(),
			C.int(n), C.int(c), C.int(h), C.int(w),
		)
		return result(res)
	default:
		if len(t.strides) > 0 {
			dimA, dimAManaged := ints2CIntPtr(t.shape)
			defer returnManaged(dimAManaged)
			strideA, strideAManaged := ints2CIntPtr(t.strides)
			defer returnManaged(strideAManaged)
			// No, there is no confusion here: the Ex variant is the one used to set a tensor without strides. Silly nVidia.
			res := C.cudnnSetTensorNdDescriptor(internal, t.dataType.C(),
				C.int(len(t.shape)), dimA, strideA)
			return result(res)
		}
		dimA, dimAManaged := ints2CIntPtr(t.shape)
		defer returnManaged(dimAManaged)
		res := C.cudnnSetTensorNdDescriptorEx(internal, t.format.C(), t.dataType.C(),
			C.int(len(t.shape)), dimA)
		return result(res)
	}
}

func (t *TensorDescriptor) Format() TensorFormat { return t.format }
func (t *TensorDescriptor) DataType() DataType { return t.dataType }
func (t *TensorDescriptor) Shape() []int { return cloneShape(t.shape) }
func (t *TensorDescriptor) Strides() []int { return cloneShape(t.strides) }
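
// destroyTensor is used as the runtime finalizer for TensorDescriptor; it
// releases the underlying cuDNN descriptor.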
func destroyTensor(obj *TensorDescriptor) { C.cudnnDestroyTensorDescriptor(obj.internal) }
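
// Usage sketch (not part of the original file): creating a descriptor for a
// rank-4 NCHW float tensor and letting cuDNN compute the strides. The enum
// names NCHW and Float are assumptions about this package's TensorFormat and
// DataType constants; substitute whichever values the package actually exports.
//
//	desc, err := NewTensorDescriptor(NCHW, Float, []int{32, 3, 224, 224}, nil)
//	if err != nil {
//		// handle the cuDNN status error
//	}
//	_ = desc.Strides() // empty here, since the strides were left to cuDNN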