Vectorized CPU code implementing left shift operator. (pytorch#88607)
This PR adds a vectorized implementation of the CPU version of the left shift operator.

All of the tests run by `pytest test/test_ops.py -vk left_shift` pass.
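
For context, below is a minimal usage sketch (values and shapes are hypothetical, chosen only for illustration) of the operation that now takes the vectorized path on CPU:

```python
import torch

# Both the operator form and the named function go through the same lshift kernel.
x = torch.randint(0, 1 << 20, (1024,), dtype=torch.int32)  # LHS values
s = torch.randint(0, 8, (1024,), dtype=torch.int32)        # bits to shift by

out_op = x << s                          # operator form
out_fn = torch.bitwise_left_shift(x, s)  # equivalent named function
assert torch.equal(out_op, out_fn)
```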

Here are some additional details:

<details>
<summary>
Benchmarking script (written by Philip, with small tweaks by Mario) comparing left shifts with multiplications - on par now
</summary>

```python
import torch
from torch import Tensor
from torch.utils.benchmark import Timer, Compare
from itertools import product
from functools import partial

# These functions exist because torch.jit.script does not support `torch.iinfo`
def _num_value_bits(dtype):
    if dtype == torch.uint8:
        return 8
    else:  # torch.int32
        return 31

def _max_value(dtype):
    if dtype == torch.uint8:
        return 255
    else:  # torch.int32
        return 2147483647

def bitshift(image, dtype):
    num_value_bits_input = _num_value_bits(image.dtype)
    num_value_bits_output = _num_value_bits(dtype)

    return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)

def mul(image, dtype):
    input_max = float(_max_value(image.dtype))
    output_max = float(_max_value(dtype))

    factor = int((output_max + 1) // (input_max + 1))
    image = image.to(dtype)
    return image * factor

size = 256
image = torch.randint(0, 256, (3, size, size), dtype=torch.uint8)
dtype = torch.int32

def gen_inputs():
    devices = ("cpu",)
    fns = (mul, bitshift)
    threads = (1,)
    for device, fn, threads in product(devices, fns, threads):
        yield f"Bitshift {device} {image.dtype}", str(tuple(image.shape)), threads, fn, image, dtype

def benchmark(label, sub_label, threads, f, *args, **kwargs):
    return Timer("f(*args, **kwargs)",
                 globals=locals(),
                 label=label,
                 description=f.__name__,
                 sub_label=sub_label,
                 num_threads=threads).blocked_autorange()

results = []
for args in gen_inputs():
    results.append(benchmark(*args))

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```
</details>

<details>
<summary>
Test script exercising a large number of combinations of left shift operands that I've used for further testing (validates results by comparing them with results generated by NumPy)
</summary>

```python
import numpy as np
import torch

# Testing shifting of non-negative numbers only, but will test all
# possible RHS shift values for given type.  For int8 and int16, we'll
# test shifting all non-negative values representable by the type.  For
# the rest of data types, we'll test shifting some random numbers in
# the corresponding range.
def _create_inputs(dtype):
    info = torch.iinfo(dtype)
    if dtype == torch.int8 or dtype == torch.int16:
        ntests = info.max + 1
        x = torch.arange(info.max + 1, dtype=dtype, device="cpu", requires_grad=False)
    else:
        ntests = 100000
        x = torch.randint(info.max + 1 if dtype != torch.int64 else info.max, (ntests,), dtype=dtype, device="cpu", requires_grad=False)
    y = torch.tensor(range(info.bits), dtype=dtype, device="cpu", requires_grad=False)
    xy = torch.cartesian_prod(x, y)
    return (xy[:, 0], xy[:, 1])

torch.manual_seed(0)

# Perform testing for each datatype supported, and compare results
# with ones generated by numpy.
for dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
    (x, y) = _create_inputs(dtype)
    z = x << y
    xnp = x.numpy()
    ynp = y.numpy()
    znp = z.numpy()
    assert((znp == (xnp << ynp)).all())
```
</details>

<details>
<summary>
Benchmarking script running the left shift operator on tensors of different lengths (and a varying number of bits to shift)
</summary>

```python
import torch
import pickle
import itertools
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(0)

# Edit this part if needed.
lengths = [1024, 4096, 16384, 65536]
rhss = [1, 2, 7, 8, 15, 16, 31, 32, 63, 64]

benchmark_name = "lshift"
label = ""
dtypes = [torch.int8, torch.int16, torch.int32, torch.int64]
results = []

# Create an argument pair for testing.  Arguments are tensors of the given
# datatype and length; the LHS of each shift operation is a random
# number, and the RHS is the given value, the same for all of them.
def _make_args(dtype, length, rhs):
    info = torch.iinfo(dtype)
    imax = info.max
    return (torch.randint(info.max, (length,), dtype=dtype, device="cpu", requires_grad=False),
            rhs * torch.ones((length,), dtype=dtype, device="cpu", requires_grad=False))

# Run the shift operation for vectors of the given lengths and for the given
# number of bits to be shifted, and record the timings.
for dtype, length, rhs in itertools.product(dtypes, lengths, rhss):
    x, y = _make_args(dtype, length, rhs)
    timer = Timer("x << y",
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"dtype={dtype},length={length}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

# Gather results.
compare = Compare(results)
compare.trim_significant_figures()
compare.print()

# Save results to a file.
with open("{}.pickle".format(label), "wb") as f:
    pickle.dump(results, f)
```
</details>

<details>
<summary>
Results of running the above benchmarking script - results manually merged from runs of viable/strict (labeled "master" in the table below) and my branch (labeled "mybranch" in the table below)
</summary>

```
[------------------- lshift -------------------------------]
                                      |  master  |  mybranch
1 threads: ------------------------------------------------
      dtype=torch.int8,length=1024    |     3    |      3
      dtype=torch.int8,length=4096    |     5    |      3
      dtype=torch.int8,length=16384   |    14    |      5
      dtype=torch.int8,length=65536   |    51    |     15
      dtype=torch.int16,length=1024   |     3    |      3
      dtype=torch.int16,length=4096   |     4    |      3
      dtype=torch.int16,length=16384  |    11    |      5
      dtype=torch.int16,length=65536  |    39    |     13
      dtype=torch.int32,length=1024   |     3    |      2
      dtype=torch.int32,length=4096   |     4    |      3
      dtype=torch.int32,length=16384  |    10    |      4
      dtype=torch.int32,length=65536  |    35    |     12
      dtype=torch.int64,length=1024   |     3    |      3
      dtype=torch.int64,length=4096   |     4    |      3
      dtype=torch.int64,length=16384  |    11    |      6
      dtype=torch.int64,length=65536  |    36    |     20

Times are in microseconds (us).
```
</details>

All of the testing/benchmarking was conducted on qpu3, which supports AVX2 only.  For basic validation of the AVX-512 update of the left shift implementation for 8-bit operands (the only one that is non-trivial in the AVX-512 case), [Compiler Explorer](https://godbolt.org/) was used, with GCC trunk and the `-mavx512f -mavx512bw` flags added.  Here are further details:

<details>
<summary>
C program used for basic validation of the AVX-512 vectorized version for 8-bit operands
</summary>

```c
#include <stdio.h>
#include <stdint.h>
#include <string.h>

#include <immintrin.h>

static void print_m512i_int8(const __m512i* x)
{
    int8_t val[64];
    memcpy(val, x, sizeof(val));
    for (int i = 0; i < 64; ++i) {
        if (i > 0)
            printf(", ");
        printf("%d", (int)val[i]);
    }
    printf("\n");
}

int main()
{
    __m512i a = _mm512_set_epi8(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                1);
    __m512i b = _mm512_set_epi8(7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6,
                                5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,
                                3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2,
                                1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
                                0);

  // ------- Copied code from vec512_int.h

  // Mask used to set upper 8 bits of each 16-bit value to 0, and keep
  // lower 8 bits.
  __m512i mask = _mm512_set_epi16(0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff);

  // Convert 8-bit operands from lower lanes to 16-bit values, and
  // perform vectorized shift.  Make sure that upper 8 bits of 16-bit
  // results are all 0.
  __m256i a_lo_8 = _mm512_extracti64x4_epi64(a, 0);
  __m256i b_lo_8 = _mm512_extracti64x4_epi64(b, 0);
  __m512i a_lo_16 = _mm512_cvtepi8_epi16(a_lo_8);
  __m512i b_lo_16 = _mm512_cvtepi8_epi16(b_lo_8);
  __m512i c_lo_16 = _mm512_and_si512(_mm512_sllv_epi16(a_lo_16, b_lo_16), mask);

  // Convert 8-bit operands from upper lanes to 16-bit values, and
  // perform vectorized shift.  Make sure that upper 8 bits of 16-bit
  // results are all 0.
  __m256i a_hi_8 = _mm512_extracti64x4_epi64(a, 1);
  __m256i b_hi_8 = _mm512_extracti64x4_epi64(b, 1);
  __m512i a_hi_16 = _mm512_cvtepi8_epi16(a_hi_8);
  __m512i b_hi_16 = _mm512_cvtepi8_epi16(b_hi_8);
  __m512i c_hi_16 = _mm512_and_si512(_mm512_sllv_epi16(a_hi_16, b_hi_16), mask);

  // Cast 16-bit results back into 8-bit values and merge them
  // together (using unsigned saturation with higher 8 bits set to 0
  // above ensures that results are correct).  Values are merged per
  // lanes, so this is not yet the final result.
  __m512i c_perm = _mm512_packus_epi16(c_lo_16, c_hi_16);

  // Permute values so that final result is produced.
  __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
  __m512i c = _mm512_permutexvar_epi64(idx, c_perm);

  // ------- End copied

    print_m512i_int8(&c);
    // Expected output: 1(x8), 2(x8), 4(x8), 8(x8), 16(x8), 32(x8), 64(x8), -128(x8)

    return 0;
}
```
</details>

Pull Request resolved: pytorch#88607
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/peterbell10
alexsamardzic authored and pytorchmergebot committed Nov 13, 2022
1 parent df1df9d commit 2aca97c
Showing 4 changed files with 308 additions and 4 deletions.
195 changes: 195 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vec256_int.h
@@ -1133,6 +1133,201 @@ inline Vectorized<int8_t> Vectorized<int8_t>::le(const Vectorized<int8_t>& other
return (*this <= other) & Vectorized<int8_t>(1);
}

template <bool left_shift>
Vectorized<int16_t> inline shift_256_16(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
// No vector instruction for shifting int16_t, so emulating it instead.

// Control masks for shuffle operation, treating 256 bits as an
// array of 16-bit elements, and considering pairs of neighboring
// elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
// M!=N) is set so that shuffle will move element with index M from
// input pair into element with index N in output pair, and element
// with index M in output pair will be set to all 0s.
__m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80,
21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80,
13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80,
5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80);
__m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26,
0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18,
0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10,
0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2);

// Masks for bitwise and operation, treating 256 bits as an array of
// 16-bit elements, and considering them in pairs of neighboring
// elements. A mask named "keep_M" (M in [0,1]) is set so that
// bitwise and will copy element with index M from input pair into
// element with the same index in output pair, while the other
// element in output pair will be set to all 0s.
__m256i keep_0 = _mm256_set1_epi32(0xFFFF);
__m256i keep_1 = _mm256_set1_epi32(0xFFFF0000);

// Take each 16-bit element with idx%2==0 from input array to be
// shifted and extend it to 32 bits so that 0s are added to the
// right. Then, perform shifting on this 32-bit number. Upper 16
// bits will be proper result of shifting original 16-bit number, so
// write them to result array, into the same position from which
// corresponding input element is taken. Also, make sure that
// result array elements with idx%2!=0 are set to all 0s.
//
// Note that number of bits to shift for is extended to 32 bits by
// adding 0s to the left. That means this number is not properly
// sign-extended for negative values. However, number of bits to
// shift is treated as an unsigned integer by respective shift
// intrinsics anyway so if negative then either with or without
// proper sign extension, it will be interpreted as a number greater
// than 32, and the shifting result will be the same.
__m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1);
__m256i b0 = _mm256_and_si256(b, keep_0);
__m256i c0;
if (left_shift)
c0 = _mm256_sllv_epi32(a0, b0);
c0 = _mm256_shuffle_epi8(c0, ctl_1_0);

// Perform shifting the same way for input array elements with
// idx%2==1.
__m256i a1 = _mm256_and_si256(a, keep_1);
__m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
__m256i c1;
if (left_shift)
c1 = _mm256_sllv_epi32(a1, b1);
c1 = _mm256_and_si256(c1, keep_1);

// Merge partial results into the final result.
__m256i c = _mm256_or_si256(c0, c1);

return c;
}

template <bool left_shift>
Vectorized<int8_t> inline shift_256_8(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
// No vector instruction for shifting int8_t, so emulating it instead.

// Control masks for shuffle operation, treating 256 bits as an
// array of 8-bit elements, and considering quadruples of
// neighboring elements. Specifically, a mask named "ctl_M_N" (M,N
// in [0,1,2,3], and M!=N) is set so that shuffle will move element
// with index M from input quadruple into element with index N in
// output quadruple, and other elements in output quadruple will be
// set to all 0s.
__m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80,
20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80,
12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80,
4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80);
__m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25,
0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17,
0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9,
0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1);
__m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80,
21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80,
13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80,
5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80);
__m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26,
0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18,
0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10,
0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2);
__m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80,
22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80,
14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80,
6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80);
__m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27,
0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19,
0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11,
0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3);
__m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80,
0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80,
0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80,
0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80);
__m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80,
0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80,
0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80,
0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80);

// Masks for bitwise and operation, treating 256 bits as an array of
// 8-bit elements, and considering them in quadruples of neighboring
// elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that
// bitwise and will copy element with index M from input quadruple
// into element with the same index in output quadruple, while the
// other elements in output quadruple will be set to all 0s.
__m256i keep_0 = _mm256_set1_epi32(0xFF);
__m256i keep_3 = _mm256_set1_epi32(0xFF000000);

// Take each 8-bit element with idx%4==0 from input array to be
// shifted and extend it to 32 bits so that 0s are added to the
// right. Then, perform shifting on this 32-bit number. Upper 8
// bits will be proper result of shifting original 8-bit number, so
// write them to result array, into the same position from which
// corresponding input element is taken. Also, make sure that
// result array elements with idx%4!=0 are set to all 0s.
//
// Note that number of bits to shift for is extended to 32 bits by
// adding 0s to the left. That means this number is not properly
// sign-extended for negative values. However, number of bits to
// shift is treated as an unsigned integer by respective shift
// intrinsics anyway so if negative then either with or without
// proper sign extension, it will be interpreted as a number greater
// than 32, and the shifting result will be the same.
__m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3);
__m256i b0 = _mm256_and_si256(b, keep_0);
__m256i c0;
if (left_shift)
c0 = _mm256_sllv_epi32(a0, b0);
c0 = _mm256_shuffle_epi8(c0, ctl_3_0);

// Perform shifting the same way for input array elements with
// idx%4==1.
__m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3);
__m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
__m256i c1;
if (left_shift)
c1 = _mm256_sllv_epi32(a1, b1);
c1 = _mm256_shuffle_epi8(c1, ctl_3_1);

// Perform shifting the same way for input array elements with
// idx%4==2.
__m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3);
__m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0);
__m256i c2;
if (left_shift)
c2 = _mm256_sllv_epi32(a2, b2);
c2 = _mm256_shuffle_epi8(c2, ctl_3_2);

// Perform shifting the same way for input array elements with
// idx%4==3.
__m256i a3 = _mm256_and_si256(a, keep_3);
__m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0);
__m256i c3;
if (left_shift)
c3 = _mm256_sllv_epi32(a3, b3);
c3 = _mm256_and_si256(c3, keep_3);

// Merge partial results into the final result.
__m256i c01 = _mm256_or_si256(c0, c1);
__m256i c23 = _mm256_or_si256(c2, c3);
__m256i c = _mm256_or_si256(c01, c23);

return c;
}

template <>
Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return _mm256_sllv_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return _mm256_sllv_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return shift_256_16<true>(a, b);
}

template <>
Vectorized<int8_t> inline operator<<(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
return shift_256_8<true>(a, b);
}

#endif

}}}
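
As an aside (not part of the commit), a scalar Python model of the trick described by the `shift_256_16` comments above may help: place the 16-bit value in the upper half of a 32-bit lane, do the 32-bit variable shift, then read back the upper 16 bits.

```python
# Illustrative scalar model of the 16-bit shift emulation used in shift_256_16.
def emulated_lshift16(a: int, b: int) -> int:
    widened = (a & 0xFFFF) << 16                 # 16-bit value in the upper half of a 32-bit lane
    if b < 32:                                   # _mm256_sllv_epi32 yields 0 for counts >= 32
        shifted = (widened << b) & 0xFFFFFFFF    # 32-bit shift, overflow bits dropped
    else:
        shifted = 0
    return shifted >> 16                         # upper 16 bits hold the shifted 16-bit value

for a in range(0, 1 << 16, 257):
    for b in range(16):
        assert emulated_lshift16(a, b) == ((a << b) & 0xFFFF)
```
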
93 changes: 93 additions & 0 deletions aten/src/ATen/cpu/vec/vec512/vec512_int.h
@@ -1163,6 +1163,99 @@ inline Vectorized<int8_t> Vectorized<int8_t>::le(const Vectorized<int8_t>& other
return (*this <= other) & Vectorized<int8_t>(1);
}

template <bool left_shift>
Vectorized<int8_t> inline shift_512_8(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
// No vector instruction for shifting int8_t, so emulating it instead.

// Control masks for shuffle operation, treating 512 bits as an
// array of 8-bit elements, and considering pairs of neighboring
// elements. Specifically, a mask named "ctl_M_N" (M,N in [0,1], and
// M!=N) is set so that shuffle will move element with index M from
// input pair into element with index N in output pair, and element
// with index M in output pair will be set to all 0s.
__m512i ctl_0_1 = _mm512_set_epi8(62, 0x80, 60, 0x80, 58, 0x80, 56, 0x80,
54, 0x80, 52, 0x80, 50, 0x80, 48, 0x80,
46, 0x80, 44, 0x80, 42, 0x80, 40, 0x80,
38, 0x80, 36, 0x80, 34, 0x80, 32, 0x80,
30, 0x80, 28, 0x80, 26, 0x80, 24, 0x80,
22, 0x80, 20, 0x80, 18, 0x80, 16, 0x80,
14, 0x80, 12, 0x80, 10, 0x80, 8, 0x80,
6, 0x80, 4, 0x80, 2, 0x80, 0, 0x80);
__m512i ctl_1_0 = _mm512_set_epi8(0x80, 63, 0x80, 61, 0x80, 59, 0x80, 57,
0x80, 55, 0x80, 53, 0x80, 51, 0x80, 49,
0x80, 47, 0x80, 45, 0x80, 43, 0x80, 41,
0x80, 39, 0x80, 37, 0x80, 35, 0x80, 33,
0x80, 31, 0x80, 29, 0x80, 27, 0x80, 25,
0x80, 23, 0x80, 21, 0x80, 19, 0x80, 17,
0x80, 15, 0x80, 13, 0x80, 11, 0x80, 9,
0x80, 7, 0x80, 5, 0x80, 3, 0x80, 1);

// Masks for bitwise and operation, treating 512 bits as an array of
// 8-bit elements, and considering them in pairs of neighboring
// elements. A mask named "keep_M" (M in [0,1]) is set so that
// bitwise and will copy element with index M from input pair into
// element with the same index in output pair, while the other
// element in output pair will be set to all 0s.
__m512i keep_0 = _mm512_set1_epi16(0xFF);
__m512i keep_1 = _mm512_set1_epi16(0xFF00);

// Take each 8-bit element with idx%2==0 from input array to be
// shifted and extend it to 16 bits so that 0s are added to the
// right. Then, perform shifting on this 16-bit number. Upper 8
// bits will be proper result of shifting original 8-bit number, so
// write them to result array, into the same position from which
// corresponding input element is taken. Also, make sure that
// result array elements with idx%2!=0 are set to all 0s.
//
// Note that number of bits to shift for is extended to 16 bits by
// adding 0s to the left. That means this number is not properly
// sign-extended for negative values. However, number of bits to
// shift is treated as an unsigned integer by respective shift
// intrinsics anyway so if negative then either with or without
// proper sign extension, it will be interpreted as a number greater
// than 32, and the shifting result will be the same.
__m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1);
__m512i b0 = _mm512_and_si512(b, keep_0);
__m512i c0;
if (left_shift)
c0 = _mm512_sllv_epi16(a0, b0);
c0 = _mm512_shuffle_epi8(c0, ctl_1_0);

// Perform shifting the same way for input array elements with
// idx%2==1.
__m512i a1 = _mm512_and_si512(a, keep_1);
__m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0);
__m512i c1;
if (left_shift)
c1 = _mm512_sllv_epi16(a1, b1);
c1 = _mm512_and_si512(c1, keep_1);

// Merge partial results into the final result.
__m512i c = _mm512_or_si512(c0, c1);

return c;
}

template <>
Vectorized<int64_t> inline operator<<(const Vectorized<int64_t>& a, const Vectorized<int64_t>& b) {
return _mm512_sllv_epi64(a, b);
}

template <>
Vectorized<int32_t> inline operator<<(const Vectorized<int32_t>& a, const Vectorized<int32_t>& b) {
return _mm512_sllv_epi32(a, b);
}

template <>
Vectorized<int16_t> inline operator<<(const Vectorized<int16_t>& a, const Vectorized<int16_t>& b) {
return _mm512_sllv_epi16(a, b);
}

template <>
Vectorized<int8_t> inline operator<<(const Vectorized<int8_t>& a, const Vectorized<int8_t>& b) {
return shift_512_8<true>(a, b);
}

#endif

}}}
13 changes: 13 additions & 0 deletions aten/src/ATen/cpu/vec/vec_base.h
@@ -799,6 +799,13 @@ inline Vectorized<T> operator~(const Vectorized<T>& a) {
return a ^ ones;
}

template <class T> Vectorized<T> inline operator<<(const Vectorized<T> &a, const Vectorized<T> &b) {
Vectorized<T> c;
for (int i = 0; i != Vectorized<T>::size(); i++) {
c[i] = a[i] << b[i];
}
return c;
}

template <typename T>
inline Vectorized<T>& operator += (Vectorized<T>& a, const Vectorized<T>& b) {
@@ -826,6 +833,12 @@ inline Vectorized<T>& operator *= (Vectorized<T>& a, const Vectorized<T>& b) {
return a;
}

template <typename T>
inline Vectorized<T>& operator <<= (Vectorized<T>& a, const Vectorized<T>& b) {
a = a << b;
return a;
}

template <typename T>
inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
return a * b + c;
11 changes: 7 additions & 4 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -314,10 +314,13 @@ void bitwise_xor_kernel(TensorIteratorBase& iter) {

 void lshift_kernel(TensorIteratorBase& iter) {
   AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() {
-    cpu_kernel(iter,
-      [](scalar_t a, scalar_t b) -> scalar_t {
-        return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
-      });
+    cpu_kernel_vec(iter,
+      [](scalar_t a, scalar_t b) -> scalar_t {
+        return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
+      },
+      [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
+        return a << b;
+      });
   });
 }
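
As a side note (not from the commit), a small Python model of the scalar lambda above: the LHS is cast to its unsigned counterpart before shifting, so negative inputs shift their two's-complement bit pattern and the result wraps back into the signed range. Illustrated here for int8:

```python
# Illustrative model of `static_cast<std::make_unsigned_t<int8_t>>(a) << b`
# followed by narrowing back to int8.
def lshift_int8(a: int, b: int) -> int:
    ua = a & 0xFF                       # reinterpret as uint8
    r = (ua << b) & 0xFF                # shift and keep the low 8 bits
    return r - 256 if r >= 128 else r   # reinterpret as signed int8

assert lshift_int8(-1, 1) == -2     # two's-complement pattern shifted
assert lshift_int8(64, 1) == -128   # wraps around in int8
```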

Expand Down
