Fix aarch
lilith committed Aug 22, 2024
1 parent 4a87b35 commit 16c0e05
Showing 1 changed file with 90 additions and 88 deletions.
178 changes: 90 additions & 88 deletions imageflow_core/src/graphics/transpose.rs
@@ -37,6 +37,7 @@ unsafe fn transpose4x4_sse(A: *mut f32, B: *mut f32, lda: i32, ldb: i32) {
_mm_storeu_ps(&mut *B.offset((2 as i32 * ldb) as isize), row3);
_mm_storeu_ps(&mut *B.offset((3 as i32 * ldb) as isize), row4);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn transpose4x4_sse2(A: *const u32, B: *mut u32, stride_a: usize, stride_b: usize) {
let row1: __m128i = _mm_loadu_si128(A as *const __m128i);
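The added `#[cfg(target_arch = "x86_64")]` keeps this SSE2 helper out of non-x86_64 builds entirely; the `_mm_*` intrinsics it uses only exist under `std::arch::x86_64`, so without the gate an aarch64 build cannot compile the file. A minimal sketch of the same gating pattern, assuming a hypothetical helper name that is not part of this commit:

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
unsafe fn copy_vec_sse2(src: *const u32, dst: *mut u32) {
    use std::arch::x86_64::*;
    // Compiled only for x86_64; on aarch64 the whole item disappears,
    // so the SSE2 intrinsics never reach the compiler there.
    let v = _mm_loadu_si128(src as *const __m128i);
    _mm_storeu_si128(dst as *mut __m128i, v);
}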
@@ -114,7 +115,6 @@ unsafe fn transpose_8x8_avx2(src: *const u32, dst: *mut u32, src_stride: usize,

#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn transpose_8x8_neon(src: *const u32, dst: *mut u32, src_stride: usize, dst_stride: usize) {
// Load 8 rows of 8 32-bit integers each
let row0 = vld1q_u32(src as *const u32);
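Dropping `#[inline(always)]` here is most likely required because rustc rejects `#[inline(always)]` on a function that also carries `#[target_feature(...)]`; a plain `#[inline]` hint, or none at all, is accepted. A small sketch of the attribute combination that does compile (the helper name is hypothetical, not from this commit):

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline] // `#[inline(always)]` would be rejected next to #[target_feature]
unsafe fn add_lanes_neon(a: *const u32, b: *const u32, out: *mut u32) {
    use std::arch::aarch64::*;
    // Load two vectors of four u32 lanes, add them lane-wise, store the result.
    let sum = vaddq_u32(vld1q_u32(a), vld1q_u32(b));
    vst1q_u32(out, sum);
}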
@@ -127,24 +127,24 @@ unsafe fn transpose_8x8_neon(src: *const u32, dst: *mut u32, src_stride: usize,
let row7 = vld1q_u32(src.add(src_stride * 7) as *const u32);

// Transpose 8x8 matrix
let (tmp0, tmp1) = vtrnq_u32(row0, row1);
let (tmp2, tmp3) = vtrnq_u32(row2, row3);
let (tmp4, tmp5) = vtrnq_u32(row4, row5);
let (tmp6, tmp7) = vtrnq_u32(row6, row7);

let (tmp8, tmp9) = vuzpq_u32(tmp0, tmp2);
let (tmp10, tmp11) = vuzpq_u32(tmp1, tmp3);
let (tmp12, tmp13) = vuzpq_u32(tmp4, tmp6);
let (tmp14, tmp15) = vuzpq_u32(tmp5, tmp7);

let result0 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp8), vreinterpretq_u64_u32(tmp12)));
let result1 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp9), vreinterpretq_u64_u32(tmp13)));
let result2 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp10), vreinterpretq_u64_u32(tmp14)));
let result3 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp11), vreinterpretq_u64_u32(tmp15)));
let result4 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp8), vreinterpretq_u64_u32(tmp12)));
let result5 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp9), vreinterpretq_u64_u32(tmp13)));
let result6 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp10), vreinterpretq_u64_u32(tmp14)));
let result7 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp11), vreinterpretq_u64_u32(tmp15)));
let tmp01 = vtrnq_u32(row0, row1);
let tmp23 = vtrnq_u32(row2, row3);
let tmp45 = vtrnq_u32(row4, row5);
let tmp67 = vtrnq_u32(row6, row7);

let tmp89 = vuzpq_u32(tmp01.0, tmp23.0);
let tmp1011 = vuzpq_u32(tmp01.1, tmp23.1);
let tmp1213 = vuzpq_u32(tmp45.0, tmp67.0);
let tmp1415 = vuzpq_u32(tmp45.1, tmp67.1);

let result0 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp89.0), vreinterpretq_u64_u32(tmp1213.0)));
let result1 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp89.1), vreinterpretq_u64_u32(tmp1213.1)));
let result2 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp1011.0), vreinterpretq_u64_u32(tmp1415.0)));
let result3 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(tmp1011.1), vreinterpretq_u64_u32(tmp1415.1)));
let result4 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp89.0), vreinterpretq_u64_u32(tmp1213.0)));
let result5 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp89.1), vreinterpretq_u64_u32(tmp1213.1)));
let result6 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp1011.0), vreinterpretq_u64_u32(tmp1415.0)));
let result7 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(tmp1011.1), vreinterpretq_u64_u32(tmp1415.1)));

// Store the transposed rows
vst1q_u32(dst as *mut u32, result0);
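The rewrite above is the heart of the aarch64 fix: `vtrnq_u32` and `vuzpq_u32` return `uint32x4x2_t`, a tuple struct from `std::arch::aarch64`, so the two result vectors have to be read through the `.0`/`.1` fields. The old `let (tmp0, tmp1) = vtrnq_u32(...)` form tries to destructure it as a plain tuple and does not compile. A minimal sketch of the pattern, with an illustrative helper name that is not part of the commit:

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn trn_pair(a: *const u32, b: *const u32, out: *mut u32) {
    use std::arch::aarch64::*;
    let va = vld1q_u32(a);
    let vb = vld1q_u32(b);
    let t = vtrnq_u32(va, vb);  // uint32x4x2_t, not a (uint32x4_t, uint32x4_t) tuple
    vst1q_u32(out, t.0);        // [a0, b0, a2, b2]
    vst1q_u32(out.add(4), t.1); // [a1, b1, a3, b3]
}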
@@ -166,71 +166,71 @@ unsafe fn transpose4x4_generic(A: *mut f32, B: *mut f32, lda: i32, ldb: i32) {
}
}

#[inline]
unsafe fn transpose_block_4x4(
A: *mut f32,
B: *mut f32,
n: i32,
m: i32,
lda: i32,
ldb: i32,
block_size: i32,
) {
//#pragma omp parallel for collapse(2)
let mut i: i32 = 0 as i32;
while i < n {
let mut j: i32 = 0 as i32;
while j < m {
let max_i2: i32 = if i + block_size < n {
(i) + block_size
} else {
n
};
let max_j2: i32 = if j + block_size < m {
(j) + block_size
} else {
m
};
let mut i2: i32 = i;
while i2 < max_i2 {
let mut j2: i32 = j;
while j2 < max_j2 {
#[cfg(target_arch = "x86_64")]
{
transpose4x4_sse(
&mut *A.offset((i2 * lda + j2) as isize),
&mut *B.offset((j2 * ldb + i2) as isize),
lda,
ldb,
);
}
#[cfg(target_arch = "aarch64")]
{
transpose4x4_neon(
&mut *A.offset((i2 * lda + j2) as isize),
&mut *B.offset((j2 * ldb + i2) as isize),
lda,
ldb,
);
}
#[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))]
{
transpose4x4_generic(
&mut *A.offset((i2 * lda + j2) as isize),
&mut *B.offset((j2 * ldb + i2) as isize),
lda,
ldb,
);
}
j2 += 4 as i32
}
i2 += 4 as i32
}
j += block_size
}
i += block_size
}
}
// #[inline]
// unsafe fn transpose_block_4x4(
// A: *mut f32,
// B: *mut f32,
// n: i32,
// m: i32,
// lda: i32,
// ldb: i32,
// block_size: i32,
// ) {
// //#pragma omp parallel for collapse(2)
// let mut i: i32 = 0 as i32;
// while i < n {
// let mut j: i32 = 0 as i32;
// while j < m {
// let max_i2: i32 = if i + block_size < n {
// (i) + block_size
// } else {
// n
// };
// let max_j2: i32 = if j + block_size < m {
// (j) + block_size
// } else {
// m
// };
// let mut i2: i32 = i;
// while i2 < max_i2 {
// let mut j2: i32 = j;
// while j2 < max_j2 {
// #[cfg(target_arch = "x86_64")]
// {
// transpose4x4_sse(
// &mut *A.offset((i2 * lda + j2) as isize),
// &mut *B.offset((j2 * ldb + i2) as isize),
// lda,
// ldb,
// );
// }
// #[cfg(target_arch = "aarch64")]
// {
// transpose4x4_neon(
// &mut *A.offset((i2 * lda + j2) as isize),
// &mut *B.offset((j2 * ldb + i2) as isize),
// lda,
// ldb,
// );
// }
// #[cfg(all(not(target_arch = "aarch64"), not(target_arch = "x86_64")))]
// {
// transpose4x4_generic(
// &mut *A.offset((i2 * lda + j2) as isize),
// &mut *B.offset((j2 * ldb + i2) as isize),
// lda,
// ldb,
// );
// }
// j2 += 4 as i32
// }
// i2 += 4 as i32
// }
// j += block_size
// }
// i += block_size
// }
// }

// Generic transposition function for [u32] slices
pub fn transpose_u32_slices(
Expand Down Expand Up @@ -295,7 +295,7 @@ fn transpose_multiple_of_block_size_rectangle(
let use8x8simd = false;

#[cfg(target_arch = "x86_64")]
let use4x4simd = false;
let use4x4simd = true;

#[cfg(not(target_arch = "x86_64"))]
let use4x4simd = false;
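Flipping `use4x4simd` to `true` on x86_64 re-enables the 4x4 SSE2 path there, while the `#[cfg(not(...))]` binding keeps every other target on the fallback path. A self-contained sketch of that cfg-guarded flag pattern (the function and return values are hypothetical, not from this file):

fn pick_transpose_path() -> &'static str {
    #[cfg(target_arch = "x86_64")]
    let use4x4simd = true; // the SSE2 4x4 kernel is compiled in
    #[cfg(not(target_arch = "x86_64"))]
    let use4x4simd = false; // everything else uses the generic path

    if use4x4simd { "sse2_4x4" } else { "generic" }
}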
@@ -330,9 +330,11 @@ fn transpose_multiple_of_block_size_rectangle(
}
}
} else if use4x4simd {
for y in (y_block..max_y).step_by(4) {
for x in (x_block..max_x).step_by(4) {
unsafe {
#[cfg(target_arch = "x86_64")]
unsafe {
for y in (y_block..max_y).step_by(4) {
for x in (x_block..max_x).step_by(4) {

transpose4x4_sse2(
src.as_ptr().add(y * src_stride + x),
dst.as_mut_ptr().add(x * dst_stride + y),
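Compared with the old shape, the `#[cfg(target_arch = "x86_64")]` attribute and the `unsafe` block now wrap the whole loop nest, so the SSE2 call and the loops feeding it exist only on x86_64, and the unsafe region is entered once per block rather than once per 4x4 tile. A sketch of that hoisted shape, assuming row-major buffers whose dimensions are multiples of 4; the wrapper name is hypothetical, while `transpose4x4_sse2` is the helper from this file:

#[cfg(target_arch = "x86_64")]
fn transpose_tiles_sse2(src: &[u32], dst: &mut [u32], w: usize, h: usize) {
    // One cfg gate and one unsafe block around the whole loop nest.
    // Assumes src is w x h and dst is h x w, with w and h multiples of 4.
    unsafe {
        for y in (0..h).step_by(4) {
            for x in (0..w).step_by(4) {
                transpose4x4_sse2(
                    src.as_ptr().add(y * w + x),
                    dst.as_mut_ptr().add(x * h + y),
                    w, // source stride
                    h, // destination stride
                );
            }
        }
    }
}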
