Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Vectorized CPU code implementing left shift operator. (pytorch#88607)
This PR adds vectorized implementation for CPU version of left shift operator. All of the tests run by `pytest test/test_ops.py -vk left_shift` pass. Here are some additional details: <details> <summary> Benchmarking script (writen by Philip, with small tweaks by Mario) comparing left shifts with multiplications - on par now </summary> ```python import torch from torch import Tensor from torch.utils.benchmark import Timer, Compare from itertools import product from functools import partial # These functions exist, because torch.jit.script does not support `torch.iinfo` def _num_value_bits(dtype): if dtype == torch.uint8: return 8 else: # torch.int32 return 31 def _max_value(dtype): if dtype == torch.uint8: return 255 else: # torch.int32 return 2147483647 def bitshift(image, dtype): num_value_bits_input = _num_value_bits(image.dtype) num_value_bits_output = _num_value_bits(dtype) return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input) def mul(image, dtype): input_max = float(_max_value(image.dtype)) output_max = float(_max_value(dtype)) factor = int((output_max + 1) // (input_max + 1)) image = image.to(dtype) return image * factor size = 256 image = torch.randint(0, 256, (3, size, size), dtype=torch.uint8) dtype = torch.int32 def gen_inputs(): devices = ("cpu",) fns = (mul, bitshift) threads = (1,) for device, fn, threads in product(devices, fns, threads): yield f"Bitshift {device} {image.dtype}", str(tuple(image.shape)), threads, fn, image, dtype def benchmark(label, sub_label, threads, f, *args, **kwargs): return Timer("f(*args, **kwargs)", globals=locals(), label=label, description=f.__name__, sub_label=sub_label, num_threads=threads).blocked_autorange() results = [] for args in gen_inputs(): results.append(benchmark(*args)) compare = Compare(results) compare.trim_significant_figures() compare.print() ``` </details> <details> <summary> Test script exercising large number of combinations of left shift operands that I've used for further testing (validates results through comparing with results generated by NumPy) </summary> ```python import numpy as np import torch # Testing shifting of non-negative numbers only, but will test all # possible RHS shift values for given type. For int8 and int16, we'll # test shifting all of non-negative values represntable by type. For # the rest of data types, we'll test shifting some random numbers in # the corresponding range. def _create_inputs(dtype): info = torch.iinfo(dtype) if dtype == torch.int8 or dtype == torch.int16: ntests = info.max + 1 x = torch.arange(info.max + 1, dtype=dtype, device="cpu", requires_grad=False) else: ntests = 100000 x = torch.randint(info.max + 1 if dtype != torch.int64 else info.max, (ntests,), dtype=dtype, device="cpu", requires_grad=False) y = torch.tensor(range(info.bits), dtype=dtype, device="cpu", requires_grad=False) xy = torch.cartesian_prod(x, y) return (xy[:, 0], xy[:, 1]) torch.manual_seed(0) # Perform testing for each datatype supported, and compare results # with ones generated by numpy. for dtype in (torch.int8, torch.int16, torch.int32, torch.int64): (x, y) = _create_inputs(dtype) z = x << y xnp = x.numpy() ynp = y.numpy() znp = z.numpy() assert((znp == (xnp << ynp)).all()) ``` </details> <details> <summary> Benchmarking script running the left shift operator on tensors of different length (and varying number of bits to shift) </summary> ```python import torch import pickle import itertools from torch.utils.benchmark import Timer, Compare torch.manual_seed(0) # Edit this part if needed. lengths = [1024, 4096, 16384, 65536] rhss = [1, 2, 7, 8, 15, 16, 31, 32, 63, 64] benchmark_name = "lshift" label = "" dtypes = [torch.int8, torch.int16, torch.int32, torch.int64] results = [] # Create an argument pair for testing. Argument are tensors of given # datatype and length, LHS for each shift operation is a random # number, and RHS is given value that is same for all of them. def _make_args(dtype, length, rhs): info = torch.iinfo(dtype) imax = info.max return (torch.randint(info.max, (length,), dtype=dtype, device="cpu", requires_grad=False), rhs * torch.ones((length,), dtype=dtype, device="cpu", requires_grad=False)) # Run shift operation for vectors of given lenghts and for given # number of bits to be shifted, and remember timings. for dtype, length, rhs in itertools.product(dtypes, lengths, rhss): x, y = _make_args(dtype, length, rhs) timer = Timer("x << y", globals=globals(), label=benchmark_name, description=label, sub_label=f"dtype={dtype},length={length}", num_threads=1) results.append(timer.blocked_autorange()) # Gather results. compare = Compare(results) compare.trim_significant_figures() compare.print() # Print results. with open("{}.pickle".format(label), "wb") as f: pickle.dump(results, f) ``` </details> <details> <summary> Results of running above benchmarking script - results manually merged for runs of viable/strict (labeled "master" in the table below) and my branch (labeled "mybranch" in the table below) </summary> ``` [------------------- lshift -------------------------------] | master | mybranch 1 threads: ------------------------------------------------ dtype=torch.int8,length=1024 | 3 | 3 dtype=torch.int8,length=4096 | 5 | 3 dtype=torch.int8,length=16384 | 14 | 5 dtype=torch.int8,length=65536 | 51 | 15 dtype=torch.int16,length=1024 | 3 | 3 dtype=torch.int16,length=4096 | 4 | 3 dtype=torch.int16,length=16384 | 11 | 5 dtype=torch.int16,length=65536 | 39 | 13 dtype=torch.int32,length=1024 | 3 | 2 dtype=torch.int32,length=4096 | 4 | 3 dtype=torch.int32,length=16384 | 10 | 4 dtype=torch.int32,length=65536 | 35 | 12 dtype=torch.int64,length=1024 | 3 | 3 dtype=torch.int64,length=4096 | 4 | 3 dtype=torch.int64,length=16384 | 11 | 6 dtype=torch.int64,length=65536 | 36 | 20 Times are in microseconds (us). ``` </details> All of the testing/benchmarking was conducted on qpu3, that supports AVX2 only. For basic validation of AVX-512 update of left shift implementation for 8-bit operands (that is the only one that is non-trivial in AVX-512 case), [Compiler Explorer](https://godbolt.org/) is used, with GCC trunk and `-mavx512f -mavx512bw` flags added. Here are further details: <details> <summary> C program used for basic validation of AVX-512 vectorized version for 8-bit operands </summary> ``` #include <stdio.h> #include <stdint.h> #include <string.h> #include <immintrin.h> static void print_m512i_int8(const __m512i* x) { int8_t val[64]; memcpy(val, x, sizeof(val)); for (int i = 0; i < 64; ++i) { if (i > 0) printf(", "); printf("%d", (int)val[i]); } printf("\n"); } int main() { __m512i a = _mm512_set_epi8(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); __m512i b = _mm512_set_epi8(7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0); // ------- Copied code from vec512_int.h // Mask used to set upper 8 bits of each 16-bit value to 0, and keep // lower 8 bits. __m512i mask = _mm512_set_epi16(0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff); // Convert 8-bit operands from lower lanes to 16-bit values, and // perform vectorized shift. Make sure that upper 8 bits of 16-bit // results are all 0. __m256i a_lo_8 = _mm512_extracti64x4_epi64(a, 0); __m256i b_lo_8 = _mm512_extracti64x4_epi64(b, 0); __m512i a_lo_16 = _mm512_cvtepi8_epi16(a_lo_8); __m512i b_lo_16 = _mm512_cvtepi8_epi16(b_lo_8); __m512i c_lo_16 = _mm512_and_si512(_mm512_sllv_epi16(a_lo_16, b_lo_16), mask); // Convert 8-bit operands from upper lanes to 16-bit values, and // perform vectorized shift. Make sure that upper 8 bits of 16-bit // results are all 0. __m256i a_hi_8 = _mm512_extracti64x4_epi64(a, 1); __m256i b_hi_8 = _mm512_extracti64x4_epi64(b, 1); __m512i a_hi_16 = _mm512_cvtepi8_epi16(a_hi_8); __m512i b_hi_16 = _mm512_cvtepi8_epi16(b_hi_8); __m512i c_hi_16 = _mm512_and_si512(_mm512_sllv_epi16(a_hi_16, b_hi_16), mask); // Cast 16-bit results back into 8-bit values and merge them // together (using unsigned saturation with higher 8 bits set to 0 // above ensures that results are correct). Values are merged per // lanes, so this is not yet the final result. __m512i c_perm = _mm512_packus_epi16(c_lo_16, c_hi_16); // Permute values so that final result is produced. __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); __m512i c = _mm512_permutexvar_epi64(idx, c_perm); // ------- End copied print_m512i_int8(&c); // Expected output: 1(x8), 2(x8), 4(x8), 8(x8), 16(x8), 32(x8), 64(x8), 128(x8), -128(x8) return 0; } ``` </details> Pull Request resolved: pytorch#88607 Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/peterbell10
- Loading branch information