Skip to content

Commit

Permalink
asmgen: generate PIC
Browse files (browse the repository at this point in the history)
  • Loading branch information
jschanck committed Jul 31, 2019
1 parent 822e1d7 commit ea3c296
Show file tree
Hide file tree
Showing 18 changed files with 161 additions and 124 deletions.
14 changes: 8 additions & 6 deletions avx2-hps2048509/asmgen/poly_r2_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def mult_128x128(xy, x, y, t1, t2):
p("vinserti128 $1, %xmm{}, %ymm{}, %ymm{}".format(t2, t2, t2)) # move low of t2 to high

# TODO can we do this without masks?
p("vpand mask0011, %ymm{}, %ymm{}".format(t0, t0))
p("vpand mask0110, %ymm{}, %ymm{}".format(t1, t1))
p("vpand mask1100, %ymm{}, %ymm{}".format(t2, t2))
p("vpand mask0011(%rip), %ymm{}, %ymm{}".format(t0, t0))
p("vpand mask0110(%rip), %ymm{}, %ymm{}".format(t1, t1))
p("vpand mask1100(%rip), %ymm{}, %ymm{}".format(t2, t2))

p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t0, t1, t1))
p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t1, t2, xy))
Expand Down Expand Up @@ -52,6 +52,7 @@ def karatsuba_256x256(ab, a, b, t0, t1, t2, t3, t4):

if __name__ == '__main__':
p(".data")
p(".section .rodata")
p(".align 32")
p("mask1100:")
for i in [0]*8 + [65535]*8:
Expand All @@ -73,6 +74,7 @@ def karatsuba_256x256(ab, a, b, t0, t1, t2, t3, t4):
p(".word {}".format(i))

p(".text")
p(".hidden poly_R2_mul")
p(".global poly_R2_mul")
p(".att_syntax prefix")

Expand Down Expand Up @@ -124,8 +126,8 @@ def karatsuba_256x256(ab, a, b, t0, t1, t2, t3, t4):
# accounts for the fact that the polynomial only has 509 coefficients and not 512.
# (So x^509 instead of x^512 needs to be added to x^0)
for i, word in enumerate(w[:2]):
p("vpand mask1000, %ymm{}, %ymm{}".format(w[i+1], t2))
p("vpand mask0111, %ymm{}, %ymm{}".format(w[i+2], t1))
p("vpand mask1000(%rip), %ymm{}, %ymm{}".format(w[i+1], t2))
p("vpand mask0111(%rip), %ymm{}, %ymm{}".format(w[i+2], t1))
p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t1, t2, t1))

p("vpsrlq ${}, %ymm{}, %ymm{}".format(61, t1, t1))
Expand All @@ -136,7 +138,7 @@ def karatsuba_256x256(ab, a, b, t0, t1, t2, t3, t4):
p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t1, w[i], w[i]))

# Get rid of the last three bits (509 coefficients and not 512)
p("vpand low253, %ymm{}, %ymm{}".format(w1, w1))
p("vpand low253(%rip), %ymm{}, %ymm{}".format(w1, w1))

p("vmovdqa %ymm{}, {}(%rdi)".format(w0, 0))
p("vmovdqa %ymm{}, {}(%rdi)".format(w1, 32))
Expand Down
10 changes: 6 additions & 4 deletions avx2-hps2048509/asmgen/poly_rq_mul_x_minus_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

if __name__ == '__main__':
p(".data")
p(".section .rodata")
p(".align 32")

p("mask_mod2048:")
Expand All @@ -27,6 +28,7 @@
p(".byte 255")

p(".text")
p(".hidden poly_Rq_mul_x_minus_1")
p(".global poly_Rq_mul_x_minus_1")
p(".att_syntax prefix")

Expand All @@ -38,15 +40,15 @@
for i in range(ceil(509 / 16)-1, 0, -1):
p("vmovdqu {}(%rsi), %ymm{}".format((i*16 - 1) * 2, a_imin1))
p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(i * 32, a_imin1, t0))
p("vpand mask_mod2048, %ymm{}, %ymm{}".format(t0, t0))
p("vpand mask_mod2048(%rip), %ymm{}, %ymm{}".format(t0, t0))
p("vmovdqa %ymm{}, {}(%rdi)".format(t0, i*32))
if i == ceil(509 / 16)-1:
# a_imin1 now contains 495 to 510 inclusive;
# we need 509 for [0], which is at position 14
p("vextracti128 $1, %ymm{}, %xmm{}".format(a_imin1, t1))
p("vpshufb shuf_5_to_0_zerorest, %ymm{}, %ymm{}".format(t1, t1))
p("vpshufb shuf_5_to_0_zerorest(%rip), %ymm{}, %ymm{}".format(t1, t1))
p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(0, t1, t1))
p("vpand mask_mod2048_only_lowest, %ymm{}, %ymm{}".format(t1, t1))
p("vpand mask_mod2048_only_lowest(%rip), %ymm{}, %ymm{}".format(t1, t1))

# and now we still need to fix [1] to [15], which we cannot vmovdqu
t2 = 0
Expand All @@ -58,7 +60,7 @@
p("vpsllq $16, %ymm{}, %ymm{}".format(t4, t3))
p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t2, t3, t3))
p("vpsubw %ymm{}, %ymm{}, %ymm{}".format(t4, t3, t4))
p("vpand mask_mod2048_omit_lowest, %ymm{}, %ymm{}".format(t4, t4))
p("vpand mask_mod2048_omit_lowest(%rip), %ymm{}, %ymm{}".format(t4, t4))
p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t4, t1, t4))
p("vmovdqa %ymm{}, {}(%rdi)".format(t4, 0))

Expand Down
14 changes: 8 additions & 6 deletions avx2-hps2048509/asmgen/poly_rq_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
def mod3(a, r=13, t=14, c=15):
# r = (a >> 8) + (a & 0xff); // r mod 255 == a mod 255
p("vpsrlw $8, %ymm{}, %ymm{}".format(a, r))
p("vpand mask_ff, %ymm{}, %ymm{}".format(a, a))
p("vpand mask_ff(%rip), %ymm{}, %ymm{}".format(a, a))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# r = (r >> 4) + (r & 0xf); // r' mod 15 == r mod 15
p("vpand mask_f, %ymm{}, %ymm{}".format(r, a))
p("vpand mask_f(%rip), %ymm{}, %ymm{}".format(r, a))
p("vpsrlw $4, %ymm{}, %ymm{}".format(r, r))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3
# r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3
for _ in range(2):
p("vpand mask_3, %ymm{}, %ymm{}".format(r, a))
p("vpand mask_3(%rip), %ymm{}, %ymm{}".format(r, a))
p("vpsrlw $2, %ymm{}, %ymm{}".format(r, r))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# t = r - 3;
p("vpsubw mask_3, %ymm{}, %ymm{}".format(r, t))
p("vpsubw mask_3(%rip), %ymm{}, %ymm{}".format(r, t))
# c = t >> 15; t is signed, so shift arithmetic
p("vpsraw $15, %ymm{}, %ymm{}".format(t, c))

Expand All @@ -36,6 +36,7 @@ def mod3(a, r=13, t=14, c=15):

if __name__ == '__main__':
p(".data")
p(".section .rodata")
p(".align 32")

p("const_3_repeating:")
Expand All @@ -59,6 +60,7 @@ def mod3(a, r=13, t=14, c=15):
p(".word 0x03")

p(".text")
p(".hidden poly_Rq_to_S3")
p(".global poly_Rq_to_S3")
p(".att_syntax prefix")

Expand All @@ -69,7 +71,7 @@ def mod3(a, r=13, t=14, c=15):
threes = 3
last = 4
retval = 5
p("vmovdqa const_3_repeating, %ymm{}".format(threes))
p("vmovdqa const_3_repeating(%rip), %ymm{}".format(threes))
p("vmovdqa {}(%rsi), %ymm{}".format((ceil(509 / 16) - 1)*32, last))

p("vpsrlw $10, %ymm{}, %ymm{}".format(last, r))
Expand All @@ -80,7 +82,7 @@ def mod3(a, r=13, t=14, c=15):
mod3(last, retval)
p("vpsllw $1, %ymm{}, %ymm{}".format(retval, last))
p("vextracti128 $1, %ymm{}, %xmm{}".format(last, last))
p("vpshufb shuf_b8_to_low_doubleword, %ymm{}, %ymm{}".format(last, last))
p("vpshufb shuf_b8_to_low_doubleword(%rip), %ymm{}, %ymm{}".format(last, last))
p("vinserti128 $1, %xmm{}, %ymm{}, %ymm{}".format(last, last, last))

for i in range(ceil(509 / 16)):
Expand Down
10 changes: 6 additions & 4 deletions avx2-hps2048509/asmgen/poly_s3_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
def mod3(a, r=13, t=14, c=15):
# r = (a >> 8) + (a & 0xff); // r mod 255 == a mod 255
p("vpsrlw $8, %ymm{}, %ymm{}".format(a, r))
p("vpand mask_ff, %ymm{}, %ymm{}".format(a, a))
p("vpand mask_ff(%rip), %ymm{}, %ymm{}".format(a, a))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# r = (r >> 4) + (r & 0xf); // r' mod 15 == r mod 15
p("vpand mask_f, %ymm{}, %ymm{}".format(r, a))
p("vpand mask_f(%rip), %ymm{}, %ymm{}".format(r, a))
p("vpsrlw $4, %ymm{}, %ymm{}".format(r, r))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3
# r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3
for _ in range(2):
p("vpand mask_3, %ymm{}, %ymm{}".format(r, a))
p("vpand mask_3(%rip), %ymm{}, %ymm{}".format(r, a))
p("vpsrlw $2, %ymm{}, %ymm{}".format(r, r))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# t = r - 3;
p("vpsubw mask_3, %ymm{}, %ymm{}".format(r, t))
p("vpsubw mask_3(%rip), %ymm{}, %ymm{}".format(r, t))
# c = t >> 15; t is signed, so shift arithmetic
p("vpsraw $15, %ymm{}, %ymm{}".format(t, c))

Expand All @@ -39,6 +39,7 @@ def mod3(a, r=13, t=14, c=15):
# TODO: If poly_S3_mul shows up (significantly) in profiling consider the alternative implementation.
if __name__ == '__main__':
p(".data")
p(".section .rodata")
p(".align 32")

p("mask_ff:")
Expand All @@ -52,6 +53,7 @@ def mod3(a, r=13, t=14, c=15):
p(".word 0x03")

p(".text")
p(".hidden poly_S3_mul")
p(".global poly_S3_mul")
p(".att_syntax prefix")

Expand Down
54 changes: 28 additions & 26 deletions avx2-hps2048509/asmgen/rq_mul/poly_rq_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def idx2off(i):

if __name__ == '__main__':
p(".data")
p(".section .rodata")
p(".align 32")

p("mask_low13words:")
Expand Down Expand Up @@ -172,6 +173,7 @@ def idx2off(i):
p(".word 2047")

p(".text")
p(".hidden poly_Rq_mul")
p(".global poly_Rq_mul")
p(".att_syntax prefix")

Expand Down Expand Up @@ -211,7 +213,7 @@ def idx2off(i):
# we evaluate for first 16 coefficients of each block, then 16 remaining coefficients

const_3 = 3
p("vmovdqa const3, %ymm{}".format(const_3))
p("vmovdqa const3(%rip), %ymm{}".format(const_3))

for (prep, real) in [(a_prep, a_real), (b_prep, b_real)]:
for coeff in range(2):
Expand All @@ -226,7 +228,7 @@ def idx2off(i):
# there are 509 coefficients, not 512;
# mask out the final 3 (because of 512 mod 509)
if coeff == 1:
p("vpand mask_low13words, %ymm{}, %ymm{}".format(f3[3], f3[3]))
p("vpand mask_low13words(%rip), %ymm{}, %ymm{}".format(f3[3], f3[3]))

# retrieve f1 so we can store it on the stack and use for vpadd
f1 = [8, 9, 10, 11]
Expand Down Expand Up @@ -340,13 +342,13 @@ def alloc():
return registers.pop()

const729 = alloc()
p("vmovdqa const729, %ymm{}".format(const729))
p("vmovdqa const729(%rip), %ymm{}".format(const729))
const3_inv = alloc()
p("vmovdqa const3_inv, %ymm{}".format(const3_inv))
p("vmovdqa const3_inv(%rip), %ymm{}".format(const3_inv))
const5_inv = alloc()
p("vmovdqa const5_inv, %ymm{}".format(const5_inv))
p("vmovdqa const5_inv(%rip), %ymm{}".format(const5_inv))
const9 = alloc()
p("vmovdqa const9, %ymm{}".format(const9))
p("vmovdqa const9(%rip), %ymm{}".format(const9))

# consider swapping this around for more closely linked memory access
# they're somewhat spread around because of how the transpose worked, but
Expand All @@ -367,8 +369,8 @@ def limb(i):
p("vmovdqa {}, %ymm{}".format(limb(0), h0))
h0lo = alloc()
h0hi = alloc()
p("vpunpcklwd const0, %ymm{}, %ymm{}".format(h0, h0lo))
p("vpunpckhwd const0, %ymm{}, %ymm{}".format(h0, h0hi))
p("vpunpcklwd const0(%rip), %ymm{}, %ymm{}".format(h0, h0lo))
p("vpunpckhwd const0(%rip), %ymm{}, %ymm{}".format(h0, h0hi))
free(h0lo)
h0_2lo = alloc()
p("vpslld $1, %ymm{}, %ymm{}".format(h0lo, h0_2lo))
Expand All @@ -379,18 +381,18 @@ def limb(i):
t1 = alloc()
p("vmovdqa {}, %ymm{}".format(limb(1), t1))
t1lo = alloc()
p("vpunpcklwd const0, %ymm{}, %ymm{}".format(t1, t1lo))
p("vpunpcklwd const0(%rip), %ymm{}, %ymm{}".format(t1, t1lo))
free(t1)
t1hi = alloc()
p("vpunpckhwd const0, %ymm{}, %ymm{}".format(t1, t1hi))
p("vpunpckhwd const0(%rip), %ymm{}, %ymm{}".format(t1, t1hi))

t2 = alloc()
p("vmovdqa {}, %ymm{}".format(limb(2), t2))
t2lo = alloc()
p("vpunpcklwd const0, %ymm{}, %ymm{}".format(t2, t2lo))
p("vpunpcklwd const0(%rip), %ymm{}, %ymm{}".format(t2, t2lo))
free(t2)
t2hi = alloc()
p("vpunpckhwd const0, %ymm{}, %ymm{}".format(t2, t2hi))
p("vpunpckhwd const0(%rip), %ymm{}, %ymm{}".format(t2, t2hi))

t11lo = alloc()
p("vpaddd %ymm{}, %ymm{}, %ymm{}".format(t2lo, t1lo, t11lo))
Expand All @@ -413,18 +415,18 @@ def limb(i):
p("vpsrld $1, %ymm{}, %ymm{}".format(t12lo, t12lo))
p("vpsrld $1, %ymm{}, %ymm{}".format(t12hi, t12hi))

p("vpand mask32_to_16, %ymm{}, %ymm{}".format(t12lo, t12lo))
p("vpand mask32_to_16, %ymm{}, %ymm{}".format(t12hi, t12hi))
p("vpand mask32_to_16(%rip), %ymm{}, %ymm{}".format(t12lo, t12lo))
p("vpand mask32_to_16(%rip), %ymm{}, %ymm{}".format(t12hi, t12hi))
free(t12lo, t12hi)
r11s = alloc()
p("vpackusdw %ymm{}, %ymm{}, %ymm{}".format(t12hi, t12lo, r11s))

h6 = alloc()
p("vmovdqa {}, %ymm{}".format(limb(6), h6))
h6lo = alloc()
p("vpunpcklwd const0, %ymm{}, %ymm{}".format(h6, h6lo))
p("vpunpcklwd const0(%rip), %ymm{}, %ymm{}".format(h6, h6lo))
h6hi = alloc()
p("vpunpckhwd const0, %ymm{}, %ymm{}".format(h6, h6hi))
p("vpunpckhwd const0(%rip), %ymm{}, %ymm{}".format(h6, h6hi))
free(h6lo)
h6_2lo = alloc()
p("vpslld $1, %ymm{}, %ymm{}".format(h6lo, h6_2lo))
Expand All @@ -442,8 +444,8 @@ def limb(i):
p("vpsrld $1, %ymm{}, %ymm{}".format(t11c2lo, t11c2lo))
p("vpsrld $1, %ymm{}, %ymm{}".format(t11c2hi, t11c2hi))

p("vpand mask32_to_16, %ymm{}, %ymm{}".format(t11c2lo, t11c2lo))
p("vpand mask32_to_16, %ymm{}, %ymm{}".format(t11c2hi, t11c2hi))
p("vpand mask32_to_16(%rip), %ymm{}, %ymm{}".format(t11c2lo, t11c2lo))
p("vpand mask32_to_16(%rip), %ymm{}, %ymm{}".format(t11c2hi, t11c2hi))
free(t11c2lo, t11c2hi)
r11 = alloc()
p("vpackusdw %ymm{}, %ymm{}, %ymm{}".format(t11c2hi, t11c2lo, r11))
Expand Down Expand Up @@ -567,12 +569,12 @@ def get_limb(limbreg, i, j):
def store_limb(limbreg, i, j):
if i == 3 and j >= 4: # this part exceeds 512
return
p("vpand mask_mod2048, %ymm{}, %ymm{}".format(limbreg, limbreg))
p("vpand mask_mod2048(%rip), %ymm{}, %ymm{}".format(limbreg, limbreg))
p("vmovdqa %ymm{}, {}({})".format(limbreg, (i*128 + j * 32 + coeff*16) * 2, r_real))

if coeff == 1 and j == 3: # these are bits 509 to 512, which we must spill into stack
p("vextracti128 $1, %ymm{}, %xmm{}".format(limbreg, limbreg, limbreg))
p("vpshufb shufmin5_mask3, %ymm{}, %ymm{}".format(limbreg, limbreg))
p("vpshufb shufmin5_mask3(%rip), %ymm{}, %ymm{}".format(limbreg, limbreg))
p("vmovdqa %xmm{}, {}(%rsp)".format(limbreg, (compose_offset+0*8+j-(3-i))*32))

# these exceptional cases have bits overflowing into two limbs over;
Expand All @@ -581,7 +583,7 @@ def store_limb(limbreg, i, j):
for i in [2, 3, 4]:
tmp = alloc()
p("vextracti128 $1, %ymm{}, %xmm{}".format(h[i], tmp))
p("vpshufb shufmin5_mask3, %ymm{}, %ymm{}".format(tmp, tmp))
p("vpshufb shufmin5_mask3(%rip), %ymm{}, %ymm{}".format(tmp, tmp))
p("vmovdqa %ymm{}, {}(%rsp)".format(tmp, (far_spill_offset+i-2)*32))
free(tmp)

Expand All @@ -606,14 +608,14 @@ def store_limb(limbreg, i, j):
temp = alloc()
temp2 = alloc()
# rotate by 3 words in each lane
p("vpshufb shuf48_16, %ymm{}, %ymm{}".format(h[i+4], h[i+4]))
p("vpand {}, %ymm{}, %ymm{}".format('mask3_5_3_5', h[i+4], temp))
p("vpshufb shuf48_16(%rip), %ymm{}, %ymm{}".format(h[i+4], h[i+4]))
p("vpand mask3_5_3_5(%rip), %ymm{}, %ymm{}".format(h[i+4], temp))
# clear the 2x 3 words so that they can be added in later
p("vpand mask5_3_5_3, %ymm{}, %ymm{}".format(h[i+4], h[i+4]))
p("vpand mask5_3_5_3(%rip), %ymm{}, %ymm{}".format(h[i+4], h[i+4]))
# grab the 3 words and put into position for adding them in
p("vpermq ${}, %ymm{}, %ymm{}".format(int('11001110', 2), temp, temp))
# add in the 3 low words that stay within this 16-word chunk
p("vpand mask_keephigh, %ymm{}, %ymm{}".format(temp, temp2))
p("vpand mask_keephigh(%rip), %ymm{}, %ymm{}".format(temp, temp2))
p("vpor %ymm{}, %ymm{}, %ymm{}".format(temp2, h[i+4], h[i+4]))
free(temp2)
# if it's h3, we cannot add to another high limb
Expand Down Expand Up @@ -653,7 +655,7 @@ def store_limb(limbreg, i, j):
# exception case for two coefficients flowing from h2 into h0, h3 into h1, h4 into h2
if j == 0 and i in [0, 1, 2]:
p("vpaddw {}(%rsp), %ymm{}, %ymm{}".format((far_spill_offset+i)*32, htemp, htemp))
p("vpand mask_mod2048, %ymm{}, %ymm{}".format(htemp, htemp))
p("vpand mask_mod2048(%rip), %ymm{}, %ymm{}".format(htemp, htemp))
p("vmovdqa %ymm{}, {}({})".format(htemp, (i*128 + j * 32 + coeff*16) * 2, r_real))
free(htemp)

Expand Down
6 changes: 3 additions & 3 deletions avx2-hps2048509/bitpermutations/bitpermutations/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,10 @@ def __init__(self, value, size=256):
DATASECTION.append(self)

def __str__(self):
return "mask_{}_{}".format(self.maskindex, SALT)
return "mask_{}_{}(%rip)".format(self.maskindex, SALT)

def data(self):
output = "{}:\n".format(str(self))
output = "mask_{}_{}:\n".format(self.maskindex, SALT)
if self.size % 16 != 0:
raise NotImplementedError("Can only divide masks into words")
# TODO this can be optimized by dividing into order sizes
Expand Down Expand Up @@ -260,7 +260,7 @@ def __len__(self):
return self.size // 8

def data(self):
output = "{}:\n".format(str(self))
output = "mask_{}_{}:\n".format(self.maskindex, SALT)
# TODO we still assume bytewise indices
for i in self.indices:
if i is None:
Expand Down
Loading

0 comments on commit ea3c296

Please sign in to comment.