-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathcgemm_fermi_kernels.h
194 lines (141 loc) · 3.72 KB
/
cgemm_fermi_kernels.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#ifndef CGEMM_FERMI_KERNELS_H
#define CGEMM_FERMI_KERNELS_H
/*
-- MAGMA (version 2.1.0) --
Univ. of Tennessee, Knoxville
Univ. of California, Berkeley
Univ. of Colorado, Denver
@date August 2016
@author Jakub Kurzak
@author Stan Tomov
@author Mark Gates
See [zcds]gemm_fermi.cu for description of related files.
*/
#include "magma_internal.h"
// =============================================================================
// Precision / data-access switches. They are defined before
// gemm_stencil_defs.h (and every stencil header included later) so those
// headers can see them.
//   COMPLEX    -- complex-valued element type (presumably single precision,
//                 since DOUBLE is not defined here; confirm in gemm_stencil_defs.h)
//   TEXTURE_1D -- presumably selects the 1-D texture path for loads; confirm
#define COMPLEX
//#undef DOUBLE
#define TEXTURE_1D

#include "gemm_stencil_defs.h"

// =============================================================================
// Thread-block shape used while computing C (innermost loop).
// NOTE: the CPU driver currently assumes that every transpose variant below
// shares this same DIM_X x DIM_Y shape, so change it for all or none.
#define DIM_X  16
#define DIM_Y  16
// =============================================================================
// A x B   (version trans_nn)

// Work tile computed by one thread block.
#define BLK_M_nn  64
#define BLK_N_nn  64
#define BLK_K     16

// Thread-block shape used when staging A (dev->regs->shmem).
#define DIM_XA  32
#define DIM_YA   8

// Thread-block shape used when staging B (dev->regs->shmem).
#define DIM_XB  16
#define DIM_YB  16

// The A and B loader shapes are alternative views of the single DIM_X x DIM_Y
// thread block launched for this kernel, so each must describe exactly the
// same number of threads. Fail at compile time if a retune breaks that.
#if (DIM_XA * DIM_YA) != (DIM_X * DIM_Y) || (DIM_XB * DIM_YB) != (DIM_X * DIM_Y)
#error "trans_nn: DIM_XA*DIM_YA and DIM_XB*DIM_YB must both equal DIM_X*DIM_Y"
#endif

#undef  version
#define version trans_nn
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

// BLK_M_nn / BLK_N_nn stay defined. Only the per-section loader parameters
// are released here so the next section can redefine them.
#undef BLK_K
#undef DIM_XA
#undef DIM_YA
#undef DIM_XB
#undef DIM_YB
// =============================================================================
// A x B^T  (version trans_nt)  and  A x B^H  (version trans_nc)

// Work tile computed by one thread block.
#define BLK_M_nt  64
#define BLK_N_nt  64
#define BLK_M_nc  64
#define BLK_N_nc  64
#define BLK_K     16

// Thread-block shape used when staging A (dev->regs->shmem).
#define DIM_XA  16
#define DIM_YA  16

// Thread-block shape used when staging B (dev->regs->shmem).
#define DIM_XB  16
#define DIM_YB  16

// The A and B loader shapes are alternative views of the single DIM_X x DIM_Y
// thread block launched for this kernel, so each must describe exactly the
// same number of threads. Fail at compile time if a retune breaks that.
#if (DIM_XA * DIM_YA) != (DIM_X * DIM_Y) || (DIM_XB * DIM_YB) != (DIM_X * DIM_Y)
#error "trans_nt/trans_nc: DIM_XA*DIM_YA and DIM_XB*DIM_YB must both equal DIM_X*DIM_Y"
#endif

#undef  version
#define version trans_nt
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

#undef  version
#define version trans_nc
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

// BLK_M_nt/BLK_N_nt/BLK_M_nc/BLK_N_nc stay defined. Only the per-section
// loader parameters are released here so the next section can redefine them.
#undef BLK_K
#undef DIM_XA
#undef DIM_YA
#undef DIM_XB
#undef DIM_YB
// =============================================================================
// A^T x B^T (trans_tt), A^T x B^H (trans_tc),
// A^H x B^T (trans_ct), A^H x B^H (trans_cc)

// Work tile computed by one thread block.
#define BLK_M_tt  64
#define BLK_N_tt  64
#define BLK_M_tc  64
#define BLK_N_tc  64
#define BLK_M_ct  64
#define BLK_N_ct  64
#define BLK_M_cc  64
#define BLK_N_cc  64
#define BLK_K     16

// Thread-block shape used when staging A (dev->regs->shmem).
#define DIM_XA  16
#define DIM_YA  16

// Thread-block shape used when staging B (dev->regs->shmem).
#define DIM_XB  32
#define DIM_YB   8

// The A and B loader shapes are alternative views of the single DIM_X x DIM_Y
// thread block launched for this kernel, so each must describe exactly the
// same number of threads. Fail at compile time if a retune breaks that.
#if (DIM_XA * DIM_YA) != (DIM_X * DIM_Y) || (DIM_XB * DIM_YB) != (DIM_X * DIM_Y)
#error "trans_tt/tc/ct/cc: DIM_XA*DIM_YA and DIM_XB*DIM_YB must both equal DIM_X*DIM_Y"
#endif

#undef  version
#define version trans_tt
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

#undef  version
#define version trans_tc
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

#undef  version
#define version trans_ct
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

#undef  version
#define version trans_cc
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

// The BLK_M_xx / BLK_N_xx tile sizes stay defined. Only the per-section
// loader parameters are released here so the next section can redefine them.
#undef BLK_K
#undef DIM_XA
#undef DIM_YA
#undef DIM_XB
#undef DIM_YB
// =============================================================================
// A^T x B  (version trans_tn)  and  A^H x B  (version trans_cn)

// Work tile computed by one thread block.
#define BLK_M_tn  64
#define BLK_N_tn  64
#define BLK_M_cn  64
#define BLK_N_cn  64
#define BLK_K     16

// Thread-block shape used when staging A (dev->regs->shmem).
#define DIM_XA  16
#define DIM_YA  16

// Thread-block shape used when staging B (dev->regs->shmem).
#define DIM_XB  16
#define DIM_YB  16

// The A and B loader shapes are alternative views of the single DIM_X x DIM_Y
// thread block launched for this kernel, so each must describe exactly the
// same number of threads. Fail at compile time if a retune breaks that.
#if (DIM_XA * DIM_YA) != (DIM_X * DIM_Y) || (DIM_XB * DIM_YB) != (DIM_X * DIM_Y)
#error "trans_tn/trans_cn: DIM_XA*DIM_YA and DIM_XB*DIM_YB must both equal DIM_X*DIM_Y"
#endif

#undef  version
#define version trans_tn
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

#undef  version
#define version trans_cn
#include "gemm_stencil.cuh"
#include "gemm_kernel.cuh"

// Only COMPLEX is released at the end of this header. The other configuration
// macros (DIM_X, DIM_Y, the BLK_* tile sizes, the last section's loader
// shapes, TEXTURE_1D, version) remain defined for the including file.
#undef COMPLEX
#endif // #ifndef CGEMM_FERMI_KERNELS_H