Skip to content

Commit

Permalink
Double-wide tailSquare
Browse files Browse the repository at this point in the history
  • Loading branch information
gwoltman authored and preda committed Dec 9, 2024
1 parent 01161f8 commit 6a5f5e7
Show file tree
Hide file tree
Showing 8 changed files with 473 additions and 27 deletions.
19 changes: 18 additions & 1 deletion src/Gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,18 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector<KeyVal>&
// SMALL_H / nH
K(fftHin, "ffthin.cl", "fftHin", hN / nH),

#if DOUBLE_WIDE
// Two double-wide kernels
K(tailSquareOne, "tailsquare.cl", "tailSquareOne", SMALL_H / nH * 2),
K(tailSquare, "tailsquare.cl", "tailSquare", hN / nH - SMALL_H / nH * 2),
#elif DOUBLE_WIDE_ONEK
// One double-wide kernel
K(tailSquare, "tailsquare.cl", "tailSquare", hN / nH),
#else
// Old-style single-wide kernel
K(tailSquare, "tailsquare.cl", "tailSquare", hN / nH / 2),
#endif

K(tailMul, "tailmul.cl", "tailMul", hN / nH / 2),
K(tailMulLow, "tailmul.cl", "tailMul", hN / nH / 2, "-DMUL_LOW=1"),

Expand Down Expand Up @@ -513,6 +524,9 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector<KeyVal>&
carryB.setFixedArgs(1, bufCarry, bufBitsC);
tailMulLow.setFixedArgs(3, bufTrigH);
tailMul.setFixedArgs(3, bufTrigH);
#if DOUBLE_WIDE
tailSquareOne.setFixedArgs(2, bufTrigH);
#endif
tailSquare.setFixedArgs(2, bufTrigH);
kernIsEqual.setFixedArgs(2, bufTrue);

Expand Down Expand Up @@ -815,6 +829,9 @@ static bool testBit(u64 x, int bit) { return x & (u64(1) << bit); }

// Runs the middle-in step, the tail squaring, and the middle-out step, in that order.
// The two buffers are used alternately: every call swaps which of `out` / `inTmp` is passed first.
// NOTE(review): this assumes the (output, input) argument convention that the tailMul kernel
// signature uses -- confirm for fftMidIn / tailSquare / fftMidOut.
void Gpu::bottomHalf(Buffer<double>& out, Buffer<double>& inTmp) {
fftMidIn(out, inTmp);
#if DOUBLE_WIDE
// Two-kernel double-wide mode: tailSquareOne is launched before the main tailSquare kernel,
// with the same buffer arguments.
// NOTE(review): presumably the two kernels cover disjoint sets of lines -- confirm in tailsquare.cl.
tailSquareOne(inTmp, out);
#endif
tailSquare(inTmp, out);
fftMidOut(out, inTmp);
}
Expand Down Expand Up @@ -1360,7 +1377,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) {
}

assert(blockSize > 0 && LOG_STEP % blockSize == 0);

u32 checkStep = checkStepForErrors(blockSize, nErrors);
assert(checkStep % LOG_STEP == 0);

Expand Down
9 changes: 9 additions & 0 deletions src/Gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
#include <filesystem>
#include <cmath>

// Klunky defines for single-wide vs. double-wide tailSquare.
// Clean this up once we determine which options to make user visible.
//
// At most one of these may be set to 1. The code that consumes them tests DOUBLE_WIDE first,
// then DOUBLE_WIDE_ONEK, and otherwise falls back to the old single-wide kernel, so setting
// several at once would silently pick the first match. The check below rejects that.
#define SINGLE_WIDE 0       // Old single-wide tailSquare
#define DOUBLE_WIDE_ONEK 0  // New double-wide tailSquare in a single kernel
#define DOUBLE_WIDE 1       // New double-wide tailSquare in two kernels

#if SINGLE_WIDE + DOUBLE_WIDE_ONEK + DOUBLE_WIDE > 1
#error At most one of SINGLE_WIDE, DOUBLE_WIDE_ONEK and DOUBLE_WIDE may be set to 1
#endif

struct PRPResult;
struct Task;

Expand Down Expand Up @@ -130,6 +136,9 @@ class Gpu {

Kernel fftHin;

#if DOUBLE_WIDE
Kernel tailSquareOne;
#endif
Kernel tailSquare;
Kernel tailMul;
Kernel tailMulLow;
Expand Down
113 changes: 112 additions & 1 deletion src/cl/fftbase.cl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ void shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) {
u32 me = get_local_id(0);
local T* lds = (local T*) lds2;


#if 0
// This also works for *f* that is not a power of two.
for (u32 i = 0; i < n; ++i) { lds[i * f + (me / f) * f * n + me % f] = u[i].x; }
Expand All @@ -93,6 +92,44 @@ void shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) {
#endif
}

// Shuffle two FFT_HEIGHTs at once. Used by the double-wide tailSquare, where u and v are
// processed simultaneously by different threads of the same workgroup.
// NOTE: the lds layout used here must stay coordinated with reverseLine2 and unreverseLine2;
// breaking that coordination would require additional bar() calls. The u values occupy the
// upper half of lds memory (the first SMALL_HEIGHT T2 values), the v values occupy the lower
// half (the next SMALL_HEIGHT T2 values).
void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) {
  assert(WG == G_H);

  // Split the local id into "which half of the workgroup" and "position within that half".
  u32 localId = get_local_id(0);
  u32 part = localId / WG;
  u32 lane = localId % WG;

  // Each half gets its own SMALL_HEIGHT T2 region. Inside a region the data is accessed as
  // scalars (faster than T2 accesses): real parts first, imaginary parts SMALL_HEIGHT scalars later.
  local T *re = (local T *) (lds2 + part * SMALL_HEIGHT);
  local T *im = re + SMALL_HEIGHT;

  // Power-of-two f only. For an arbitrary f the equivalent write index would be
  //   i * f + (lane / f) * f * n + lane % f
  u32 lowBits = f - 1;
  assert((lowBits & (lowBits + 1)) == 0);
  u32 writeBase = (lane & ~lowBits) * n + (lane & lowBits);

  for (u32 i = 0; i < n; ++i) {
    u32 dst = writeBase + i * f;
    re[dst] = u[i].x;
    im[dst] = u[i].y;
  }

  if (WG > WAVEFRONT) { bar(); }

  for (u32 i = 0; i < n; ++i) {
    u32 src = i * WG + lane;
    u[i].x = re[src];
    u[i].y = im[src];
  }
}


void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f) {
u32 me = get_local_id(0);

Expand Down Expand Up @@ -160,3 +197,77 @@ void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f) {
#error CLEAN must be 0 or 1
#endif
}

// Tabmul two simultaneous FFT_HEIGHTs. Needed for the new double-wide tailSquare where u and v are
// computed simultaneously in different threads.
//
// Multiplies u[1] .. u[n-1] by values taken from (or derived from) the `trig` table; u[0] is left untouched.
//   WG   - number of threads working on one FFT_HEIGHT; the local id is reduced modulo WG
//   trig - table of multipliers, indexed by p (and by (i-1)*WG + p when CLEAN == 1)
//   u    - the n values owned by this thread, updated in place
//   n    - number of values in u (u[1] and, for CLEAN == 0, u[2] are accessed unconditionally)
//   f    - current stride; the mask form of p below requires f to be a power of two

void tabMul2(u32 WG, Trig trig, T2 *u, u32 n, u32 f) {
u32 me = get_local_id(0);

// Reduce the local id modulo WG so the table index depends only on the position within one group of WG threads.
me = me % WG;
// p is `me` rounded down to a multiple of f. The disabled form also works when f is not a power of two.
#if 0
u32 p = me / f * f;
#else
u32 p = me & ~(f - 1);
#endif

// Disabled variant that computes the multipliers with slowTrig_N instead of reading the table.
// NOTE(review): if re-enabled as-is, `w` would be declared twice (here and below).
#if 0
T2 w = slowTrig_N(ND / n / WG * p, ND / n);
T2 base = w;
for (int i = 1; i < n; ++i) {
u[i] = cmul(u[i], w);
w = cmul(w, base);
}
#endif

// First multiplier, read from the table.
T2 w = trig[p];

// NOTE(review): presumably the table entry is stored in the "fancy" representation when n >= 8,
// hence the different multiply routine -- confirm against the trig table generator.
if (n >= 8) {
u[1] = cmulFancy(u[1], w);
} else {
u[1] = cmul(u[1], w);
}

// The remaining multipliers (for u[2] .. u[n-1]) are obtained in one of the ways selected below.
#if 0//CLEAN == 2 // Titan V likes this case -- only one read. But Z not helped much. Needs more investigation.

// Disabled: one extra table read at the midpoint, then multipliers are derived outward in both directions.
// NOTE(review): the loop bound below uses MIDDLE rather than n -- check before re-enabling.
u32 midpt = (n + 1) / 2;
T2 base = trig[(midpt-1)*WG + p];
T2 tmp = cmulFancyDual_setup(base, w);
u[midpt] = cmul(u[midpt], base);
T2 base1 = cmulFancyDual_conj(base, w, tmp);
u[midpt - 1] = cmul(u[midpt - 1], base1);
T2 base2 = cmulFancyDual_plain(base, w, tmp);
u[midpt + 1] = cmul(u[midpt + 1], base2);
for (u32 i = 2; midpt + i < MIDDLE; ++i) {
if (midpt - i > 1) u[midpt - i] = cmul(u[midpt - i], base1 = cmulFancy(base1, conjugate(w)));
u[midpt + i] = cmul(u[midpt + i], base2 = cmulFancy(base2, w));
}

#elif CLEAN == 1 // Radeon VII loves this case

// Every multiplier is read directly from the table: one read per u[i].
for (u32 i = 2; i < n; ++i) {
T2 base = trig[(i-1)*WG + p];
u[i] = cmul(u[i], base);
}

#elif CLEAN == 0
// No further table reads: the multipliers for u[2], u[3], ... are derived from w
// (judging by the helper names: its square, its cube, then repeated multiplication by w).
if (n >= 8) {
T2 base = csqTrigFancyFancy(w);
u[2] = cmulFancy(u[2], base);
base = ccubeTrigFancy(base, w);
for (u32 i = 3; i < n; ++i) {
u[i] = cmul(u[i], base);
base = cmulFancy(base, w);
}
} else {
T2 base = csqTrig(w);
u[2] = cmul(u[2], base);
base = ccubeTrig(base, w);
for (u32 i = 3; i < n; ++i) {
u[i] = cmul(u[i], base);
base = cmul(base, w);
}
}
#else
#error CLEAN must be 0 or 1
#endif
}
15 changes: 15 additions & 0 deletions src/cl/fftheight.cl
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,19 @@ void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) {
fft_NH(u);
}

// Variant of fft_HEIGHT built on tabMul2 / shufl2, the helpers that handle two FFT_HEIGHTs
// side by side in one workgroup. Alternates fft_NH passes with a twiddle multiply and an lds
// shuffle, then finishes with one more fft_NH pass.
// The parameter `w` is not referenced in this body.
void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) {
#if !UNROLL_H
  __attribute__((opencl_unroll_hint(1)))
#endif
  for (u32 stride = 1; stride < SMALL_HEIGHT / NH; stride *= NH) {
    // From the second pass on, wait for the previous shuffle's lds reads before lds is
    // overwritten again; no barrier is issued when SMALL_HEIGHT / NH does not exceed WAVEFRONT.
    if (stride > 1 && SMALL_HEIGHT / NH > WAVEFRONT) {
      bar();
    }
    fft_NH(u);
    tabMul2(SMALL_HEIGHT / NH, trig, u, NH, stride);
    shufl2(SMALL_HEIGHT / NH, lds, u, NH, stride);
  }
  fft_NH(u);
}

#endif
11 changes: 5 additions & 6 deletions src/cl/middle.cl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ void writeMiddleInLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x)

// Read a line for tailFused or fftHin
// This reads partially transposed data as written by fftMiddleIn
void readTailFusedLine(CP(T2) in, T2 *u, u32 line) {
u32 me = get_local_id(0);
void readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) {
u32 SIZEY = IN_WG / IN_SIZEX;

#if PADDING
Expand Down Expand Up @@ -199,17 +198,17 @@ void readTailFusedLine(CP(T2) in, T2 *u, u32 line) {
// i in u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT)
// y ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) (processed in batches of OUT_WG/OUT_SIZEX)

void writeTailFusedLine(T2 *u, P(T2) out, u32 line) {
void writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) {
#if PADDING
#if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16
u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE;
out += line * (SMALL_HEIGHT + PAD_SIZE) + line / MIDDLE * BIG_PAD_SIZE + (u32) get_local_id(0); // Pad every output line plus every MIDDLE
out += line * (SMALL_HEIGHT + PAD_SIZE) + line / MIDDLE * BIG_PAD_SIZE + me; // Pad every output line plus every MIDDLE
#else
out += line * (SMALL_HEIGHT + PAD_SIZE) + (u32) get_local_id(0); // Pad every output line
out += line * (SMALL_HEIGHT + PAD_SIZE) + me; // Pad every output line
#endif
for (u32 i = 0; i < NH; ++i) { out[i * G_H] = u[i]; }
#else // No padding, might be better on nVidia cards
out += line * SMALL_HEIGHT + (u32) get_local_id(0);
out += line * SMALL_HEIGHT + me;
for (u32 i = 0; i < NH; ++i) { out[i * G_H] = u[i]; }
#endif
}
Expand Down
14 changes: 7 additions & 7 deletions src/cl/tailmul.cl
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) {
u32 line2 = line1 ? H - line1 : (H / 2);
u32 memline1 = transPos(line1, MIDDLE, WIDTH);
u32 memline2 = transPos(line2, MIDDLE, WIDTH);

readTailFusedLine(in, u, line1);
readTailFusedLine(in, v, line2);

u32 me = get_local_id(0);
readTailFusedLine(in, u, line1, me);
readTailFusedLine(in, v, line2, me);

#if NH == 8
T2 w = fancyTrig_N(ND / SMALL_HEIGHT * me);
#else
Expand All @@ -87,8 +87,8 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) {
bar();
fft_HEIGHT(lds, v, smallTrig, w);
#else
readTailFusedLine(a, p, line1);
readTailFusedLine(a, q, line2);
readTailFusedLine(a, p, line1, me);
readTailFusedLine(a, q, line2, me);
fft_HEIGHT(lds, u, smallTrig, w);
bar();
fft_HEIGHT(lds, v, smallTrig, w);
Expand Down Expand Up @@ -122,6 +122,6 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) {
fft_HEIGHT(lds, v, smallTrig, w);
bar();
fft_HEIGHT(lds, u, smallTrig, w);
writeTailFusedLine(v, out, memline2);
writeTailFusedLine(u, out, memline1);
writeTailFusedLine(v, out, memline2, me);
writeTailFusedLine(u, out, memline1, me);
}
Loading

0 comments on commit 6a5f5e7

Please sign in to comment.