Skip to content

Commit

Permalink
tailSquare double-wide: halve the LDS used
Browse files Browse the repository at this point in the history
  • Loading branch information
preda committed Dec 10, 2024
1 parent ea5fa54 commit 1ba0bbe
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 62 deletions.
31 changes: 19 additions & 12 deletions src/cl/base.cl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ typedef global const double2* BigTab;
// Prototypes
void read(u32 WG, u32 N, T2 *u, const global T2 *in, u32 base);
void write(u32 WG, u32 N, T2 *u, global T2 *out, u32 base);
void bar(void);
OVERLOAD void bar();
OVERLOAD void bar(u32);

void read(u32 WG, u32 N, T2 *u, const global T2 *in, u32 base) {
in += base + (u32) get_local_id(0);
Expand All @@ -162,17 +163,6 @@ void write(u32 WG, u32 N, T2 *u, global T2 *out, u32 base) {

// Packs two scalars into a T2 using an OpenCL vector literal: *a* is listed first, *b* second.
T2 U2(T a, T b) {
  T2 pair = (T2) (a, b);
  return pair;
}

void bar() {
  // The correct form is barrier(CLK_LOCAL_MEM_FENCE). On some GPUs (in particular Radeon VII and
  // Radeon PRO VII) barrier(0) turns out to work as well and is faster, so the faster variant
  // can be selected, where it works, with -use FAST_BARRIER.
#if !FAST_BARRIER
  barrier(CLK_LOCAL_MEM_FENCE);
#else
  barrier(0);
#endif
}

// On "classic" AMD GCN GPUs such as Radeon VII, the wavefront size was always 64. On RDNA GPUs the wavefront can
// be configured to be either 64 or 32. We use the FAST_BARRIER define as an indicator for GCN GPUs.
// On Nvidia GPUs the wavefront size is 32.
Expand All @@ -183,3 +173,20 @@ void bar() {
#define WAVEFRONT 32
#endif
#endif

// Unconditional work-group barrier.
OVERLOAD void bar() {
  // barrier(CLK_LOCAL_MEM_FENCE) is the correct call. It was found that on some GPUs
  // (notably Radeon VII and Radeon PRO VII) barrier(0) also works and is faster;
  // -use FAST_BARRIER opts into that faster variant on hardware where it works.
#if !FAST_BARRIER
  barrier(CLK_LOCAL_MEM_FENCE);
#else
  barrier(0);
#endif
}

// Conditional barrier: calls bar() only when WG is larger than WAVEFRONT, otherwise does nothing.
// NOTE(review): presumably WG is the number of cooperating threads, and a group that fits in a
// single wavefront needs no explicit barrier -- confirm against the callers.
OVERLOAD void bar(u32 WG) {
  if (WG <= WAVEFRONT) { return; }
  bar();
}

// Half-barrier: for the case where only half of the workgroup needs to synchronize.
// Used e.g. by the double-wide tailSquare, where LDS is split between the two halves.
// The barrier is issued only when half of the enqueued local size exceeds WAVEFRONT.
void halfBar() {
  if (get_enqueued_local_size(0) / 2 <= WAVEFRONT) { return; }
  bar();
}
40 changes: 9 additions & 31 deletions src/cl/fftbase.cl
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,6 @@ void shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) {
u32 me = get_local_id(0);
local T* lds = (local T*) lds2;

#if 0
// This also works for *f* that is not a power of two.
for (u32 i = 0; i < n; ++i) { lds[i * f + (me / f) * f * n + me % f] = u[i].x; }
bar();
for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; }
bar();
for (u32 i = 0; i < n; ++i) { lds[i * f + (me / f) * f * n + me % f] = u[i].y; }
bar();
for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; }
#else
u32 mask = f - 1;
assert((mask & (mask + 1)) == 0);
for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; }
Expand All @@ -89,7 +79,6 @@ void shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) {
for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; }
bar();
for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; }
#endif
}

// Shufl two simultaneous FFT_HEIGHTs. Needed for tailSquare, where u and v are computed simultaneously in different threads.
Expand All @@ -101,32 +90,21 @@ void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) {

// Partition lds memory into upper and lower halves
assert(WG == G_H);
lds2 += (me / WG) * SMALL_HEIGHT;

// Accessing lds memory as doubles is faster than T2 accesses
local T* lds = (local T*) lds2;
local T* lds = ((local T*) lds2) + (me / WG) * SMALL_HEIGHT;

me = me % WG;
#if 0
// This also works for *f* that is not a power of two.
for (u32 i = 0; i < n; ++i) {
u32 idx = i * f + (me / f) * f * n + me % f;
lds[idx] = u[i].x;
lds[SMALL_HEIGHT + idx] = u[i].y;
}
if (WG > WAVEFRONT) bar();
for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; u[i].y = lds[SMALL_HEIGHT + i * WG + me]; }
#else
u32 mask = f - 1;
assert((mask & (mask + 1)) == 0);
for (u32 i = 0; i < n; ++i) {
u32 idx = i * f + (me & ~mask) * n + (me & mask);
lds[idx] = u[i].x;
lds[SMALL_HEIGHT + idx] = u[i].y;
}
if (WG > WAVEFRONT) bar();
for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; u[i].y = lds[SMALL_HEIGHT + i * WG + me]; }
#endif

for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; }
bar(WG);
for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; }
bar(WG);
for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; }
bar(WG);
for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; }
}


Expand Down
4 changes: 2 additions & 2 deletions src/cl/fftheight.cl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) {
void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) {
u32 WG = SMALL_HEIGHT / NH;
for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) {
if (/* WG > WAVEFRONT && */ s > 1) { bar(); }
if (s > 1) { bar(WG); }
fft_NH(u);
w = bcast(w, s);

Expand Down Expand Up @@ -84,7 +84,7 @@ void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) {
#endif

for (u32 s = 1; s < WG; s *= NH) {
if (s > 1) { if (WG > WAVEFRONT) bar(); }
if (s > 1) { bar(WG); }
fft_NH(u);
tabMul(WG, trig, u, NH, s, me % WG);
shufl2(WG, lds, u, NH, s);
Expand Down
22 changes: 15 additions & 7 deletions src/cl/tailsquare.cl
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ void pairSq2_special(T2 *u, T2 base_squared) {
}

KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) {
local T2 lds[SMALL_HEIGHT*2]; // change reverse line to halve this
local T2 lds[SMALL_HEIGHT];

T2 u[NH];

Expand All @@ -263,7 +263,10 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) {
u32 me = get_local_id(0);
u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups).

u32 line = (me < G_H) ? line_u : line_v;
// It's not clear where up/down is, so we're going to call the halves "first-half" and "second-half".
bool isSecondHalf = me >= G_H;

u32 line = !isSecondHalf ? line_u : line_v;

// Read lines u and v
readTailFusedLine(in, u, line, lowMe);
Expand All @@ -278,12 +281,17 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) {

T2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2);

revSwapLine(G_H, lds, u + NH/2, NH/2);
bar(G_H);

revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf);
pairSq(NH/2, u, u + NH/2, trig, false);
revSwapLine(G_H, lds, u + NH/2, NH/2);

// if (G_H > WAVEFRONT) bar();
bar();

bar(G_H);
// We change the LDS halves we're using in order to enable half-bars
revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf);

bar(G_H);

fft_HEIGHT2(lds, u, smallTrig, w);

// Write lines u and v
Expand Down
17 changes: 7 additions & 10 deletions src/cl/tailutil.cl
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,18 @@ void reverseLine(u32 WG, local T2 *lds2, T2 *u) {
for (u32 i = 0; i < NH; ++i) { u[i] = lds[WG * i]; }
}

void revSwapLine(u32 WG, local T2* lds2, T2 *u, u32 n) {
// This is used to reverse the second part of a line, and cross the reversed parts between the halves.
void revCrossLine(u32 WG, local T2* lds2, T2 *u, u32 n, bool writeSecondHalf) {
u32 me = get_local_id(0);
u32 lowMe = me % WG;
bool upHalf = me >= WG;

u32 revLowMe = WG - 1 - lowMe;

// We are initially using LDS with the same half-discipline as the fft_HEIGHT() which precedes us,
// so only a half-bar is needed here.
// if (WG > WAVEFRONT) { bar(); }
bar();
for (u32 i = 0; i < n; ++i) { lds2[WG * n * writeSecondHalf + WG * (n - 1 - i) + revLowMe] = u[i]; }

u32 revLowMe = WG - 1 - lowMe;
bar(); // we need a full bar because we're crossing halves

for (u32 i = 0; i < n; ++i) { lds2[WG * n * upHalf + WG * (n - 1 - i) + revLowMe] = u[i]; }
bar(); // we need a full bar because we're going to swap halves
for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !upHalf + WG * i + lowMe]; }
for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; }
}

// computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x)
Expand Down

0 comments on commit 1ba0bbe

Please sign in to comment.