/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved. *
* This file is part of the LIBXSMM library. *
* *
* For information on the license, see the LICENSE file. *
* Further information: https://github.com/hfp/libxsmm/ *
* SPDX-License-Identifier: BSD-3-Clause *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_INTRINSICS_X86_H
#define LIBXSMM_INTRINSICS_X86_H
#include "libxsmm_cpuid.h"
/** Macro evaluates to LIBXSMM_ATTRIBUTE_TARGET_xxx (see below). */
#define LIBXSMM_ATTRIBUTE_TARGET(TARGET) LIBXSMM_CONCATENATE(LIBXSMM_ATTRIBUTE_TARGET_, TARGET)
#if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_PLATFORM_X86)
# define LIBXSMM_INTRINSICS_NONE
#endif
#if /*no intrinsics: tested with 17.x and 18.x*/(defined(__PGI) && \
LIBXSMM_VERSION2(19, 0) > LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__)) \
|| /*legacy*/(defined(_CRAYC) && !defined(__GNUC__))
# if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_STATIC)
# define LIBXSMM_INTRINSICS_NONE
# endif
#elif !defined(LIBXSMM_INTRINSICS_STATIC) && !defined(LIBXSMM_INTRINSICS_NONE) && ( \
(defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && \
LIBXSMM_VERSION2(4, 4) > LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) /* GCC 4.4 (target-attribute) */ \
|| (defined(__clang__) && LIBXSMM_VERSION2(3, 7) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
|| (defined(__APPLE__) && defined(__MACH__) && !defined(LIBXSMM_INTEL_COMPILER) && defined(__clang__) && \
LIBXSMM_VERSION2(9, 0) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)))
# define LIBXSMM_INTRINSICS_STATIC
#endif
#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif
/** https://github.com/intel/Immintrin-debug */
#if !defined(LIBXSMM_INTRINSICS_DEBUG) && 0
# define LIBXSMM_INTRINSICS_DEBUG
/* workarounds removed after LIBXSMM 1.16.1-1268 */
# include "immintrin_dbg.h"
#endif
#if defined(__MIC__) && !defined(LIBXSMM_INTRINSICS_NONE)
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH_GENERIC
# endif
# define LIBXSMM_INTRINSICS(TARGET)
# define LIBXSMM_INTRINSICS_INCLUDE
#elif !defined(LIBXSMM_INTRINSICS_NONE) /*!defined(__MIC__)*/
# if defined(__AVX512F__) && defined(__AVX512CD__) \
&& defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__) \
&& defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
&& (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) /* TODO: check GCC, Clang, etc. */ \
|| (LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
&& (!defined(__clang__) || (LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
&& (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(99, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__AVX512F__) && defined(__AVX512CD__) \
&& defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX512VNNI__) \
&& defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
&& (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
|| (LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
&& (!defined(__clang__) || (LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
&& (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__AVX512F__) && defined(__AVX512CD__) \
&& defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) \
&& defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
&& (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
|| (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
&& (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
&& (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__AVX512F__) && defined(__AVX512CD__) \
&& defined(__AVX512PF__) && defined(__AVX512ER__) \
&& defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
&& (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
|| (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
&& (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
&& (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_MIC
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__AVX512F__) && defined(__AVX512CD__) \
&& defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
&& (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
|| (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
&& (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
&& (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_SSE42
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__SSE3__)
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_SSE3
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(LIBXSMM_PLATFORM_X86)
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_GENERIC
# endif
# if defined(__GNUC__)
# define LIBXSMM_INTRINSICS_INCLUDE
# endif
# endif
# if defined(LIBXSMM_STATIC_TARGET_ARCH) && !defined(LIBXSMM_INTRINSICS_STATIC)
# if defined(__INTEL_COMPILER)
# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
/* TODO: compiler version check for LIBXSMM_MAX_STATIC_TARGET_ARCH */
# if 1904 <= (LIBXSMM_INTEL_COMPILER) && !defined(_WIN32)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
# elif 1801 <= (LIBXSMM_INTEL_COMPILER)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
# elif 1500 <= (LIBXSMM_INTEL_COMPILER)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
# elif 1400 <= (LIBXSMM_INTEL_COMPILER)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_MIC
# else
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
# endif
# endif
# define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(_CRAYC) && defined(__GNUC__)
/* TODO: version check, e.g., LIBXSMM_VERSION2(11, 5) <= LIBXSMM_VERSION2(_RELEASE, _RELEASE_MINOR) */
# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX
# endif
# define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
# define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(_MSC_VER) && !defined(__clang__)
/* TODO: compiler version check for LIBXSMM_MAX_STATIC_TARGET_ARCH */
# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
# endif
# define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
# define LIBXSMM_INTRINSICS_INCLUDE
# elif (!defined(__GNUC__) || LIBXSMM_VERSION2(4, 9) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
&& (!defined(__clang__) || LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
&& (!defined(__APPLE__) || !defined(__MACH__)) && !defined(__PGI) && !defined(_MSC_VER)
# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
# if defined(__CYGWIN__) && !defined(LIBXSMM_INTRINSICS_DEBUG) /* Cygwin: invalid register for .seh_savexmm */
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
# elif (defined(__clang__) && LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
# elif (defined(__GNUC__) && LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
|| (defined(__clang__) && LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && !defined(__cray__))
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
# elif (defined(__GNUC__) && LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
|| (defined(__clang__) && LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
# elif (defined(__GNUC__) && LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
|| (defined(__clang__) && LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
# else
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
# endif
# endif
# define LIBXSMM_INTRINSICS_INCLUDE
# else /* GCC/legacy incl. Clang */
# if defined(__clang__) && !(defined(__APPLE__) && defined(__MACH__)) && !defined(_WIN32)
# if (LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) /* TODO */
/* no limitations */
# elif (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# if !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_STATIC_TARGET_ARCH < LIBXSMM_X86_AVX2/*workaround*/)
# define LIBXSMM_INTRINSICS_STATIC
# endif
# elif !defined(LIBXSMM_INTRINSICS_STATIC)
# define LIBXSMM_INTRINSICS_STATIC
# endif
# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
# if defined(__CYGWIN__) && !defined(LIBXSMM_INTRINSICS_DEBUG) /* Cygwin: invalid register for .seh_savexmm */
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
# elif LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
# elif LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && !defined(__cray__)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
# elif LIBXSMM_VERSION2( 6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
# else
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
# endif
# endif
# else /* fallback */
# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
# endif
# if !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_STATIC_TARGET_ARCH < LIBXSMM_X86_AVX2/*workaround*/)
# define LIBXSMM_INTRINSICS_STATIC
# endif
# endif
# if !defined(LIBXSMM_INTRINSICS_INCLUDE) && (!defined(__PGI) || LIBXSMM_VERSION2(19, 0) <= LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__))
# define LIBXSMM_INTRINSICS_INCLUDE
# endif
# endif /* GCC/legacy incl. Clang */
# if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
# error "LIBXSMM_MAX_STATIC_TARGET_ARCH not defined!"
# endif
# if defined(LIBXSMM_TARGET_ARCH) && (LIBXSMM_TARGET_ARCH < LIBXSMM_MAX_STATIC_TARGET_ARCH)
# undef LIBXSMM_MAX_STATIC_TARGET_ARCH
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH
# endif
# if defined(LIBXSMM_INTRINSICS_INCLUDE) && !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_DEBUG)
# include <immintrin.h>
# endif /*defined(LIBXSMM_INTRINSICS_INCLUDE)*/
# if !defined(LIBXSMM_INTRINSICS)
# if (LIBXSMM_MAX_STATIC_TARGET_ARCH > LIBXSMM_STATIC_TARGET_ARCH)
# define LIBXSMM_INTRINSICS(TARGET) LIBXSMM_ATTRIBUTE(LIBXSMM_ATTRIBUTE_TARGET(TARGET))
/* LIBXSMM_ATTRIBUTE_TARGET_xxx is required to literally match the CPUID (libxsmm_cpuid.h)! */
# define LIBXSMM_ATTRIBUTE_TARGET_1002 target("sse2") /* LIBXSMM_X86_GENERIC (64-bit ABI) */
# if (LIBXSMM_X86_SSE3 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1003 target("sse3")
# else
# define LIBXSMM_ATTRIBUTE_TARGET_1003 LIBXSMM_ATTRIBUTE_TARGET_1002
# endif
# if (LIBXSMM_X86_SSE42 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1004 target("sse4.1,sse4.2")
# else
# define LIBXSMM_ATTRIBUTE_TARGET_1004 LIBXSMM_ATTRIBUTE_TARGET_1003
# endif
# if (LIBXSMM_X86_AVX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1005 target("avx")
# else
# define LIBXSMM_ATTRIBUTE_TARGET_1005 LIBXSMM_ATTRIBUTE_TARGET_1004
# endif
# if (LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1006 target("avx2,fma")
# else
# define LIBXSMM_ATTRIBUTE_TARGET_1006 LIBXSMM_ATTRIBUTE_TARGET_1005
# endif
# if (LIBXSMM_X86_AVX512 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1007 target("avx2,fma,avx512f,avx512cd")
# else
# define LIBXSMM_ATTRIBUTE_TARGET_1007 LIBXSMM_ATTRIBUTE_TARGET_1006
# endif
# if (LIBXSMM_X86_AVX512_MIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1010 target("avx2,fma,avx512f,avx512cd,avx512pf,avx512er")
# else /* LIBXSMM_X86_AVX512 */
# define LIBXSMM_ATTRIBUTE_TARGET_1010 LIBXSMM_ATTRIBUTE_TARGET_1007
# endif
# if (LIBXSMM_X86_AVX512_KNM <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1011 target("avx2,fma,avx512f,avx512cd,avx512pf,avx512er,avx5124vnniw,avx5124fmaps")
# else /* LIBXSMM_X86_AVX512_MIC */
# define LIBXSMM_ATTRIBUTE_TARGET_1011 LIBXSMM_ATTRIBUTE_TARGET_1010
# endif
# if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1020 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl")
# else /* LIBXSMM_X86_AVX512 */
# define LIBXSMM_ATTRIBUTE_TARGET_1020 LIBXSMM_ATTRIBUTE_TARGET_1007
# endif
# if (LIBXSMM_X86_AVX512_CLX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1021 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl,avx512vnni")
# else /* LIBXSMM_X86_AVX512_CORE */
# define LIBXSMM_ATTRIBUTE_TARGET_1021 LIBXSMM_ATTRIBUTE_TARGET_1020
# endif
# if (LIBXSMM_X86_AVX512_CPX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_ATTRIBUTE_TARGET_1022 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl,avx512vnni,avx512bf16")
# else /* LIBXSMM_X86_AVX512_CORE */
# define LIBXSMM_ATTRIBUTE_TARGET_1022 LIBXSMM_ATTRIBUTE_TARGET_1021
# endif
# else
# define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
# endif
# elif !defined(LIBXSMM_INTRINSICS_TARGET)
# define LIBXSMM_INTRINSICS_TARGET
# endif /*!defined(LIBXSMM_INTRINSICS)*/
# endif /*defined(LIBXSMM_STATIC_TARGET_ARCH)*/
#endif /*!defined(LIBXSMM_INTRINSICS_NONE)*/
#if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_STATIC)
# define LIBXSMM_INTRINSICS_NONE
# endif
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH_GENERIC
#endif
#if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
#elif (LIBXSMM_MAX_STATIC_TARGET_ARCH < LIBXSMM_STATIC_TARGET_ARCH)
# undef LIBXSMM_MAX_STATIC_TARGET_ARCH
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
#endif
#if !defined(LIBXSMM_INTRINSICS)
# define LIBXSMM_INTRINSICS(TARGET)
#endif
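/* Usage sketch (hypothetical function name, not part of this header): LIBXSMM_INTRINSICS(TARGET)
 * decorates a function so the compiler may emit instructions beyond the baseline ISA, e.g.,
 *
 *   LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2)
 *   __m256 internal_example_add8(const float* a, const float* b) {
 *     return _mm256_add_ps(_mm256_loadu_ps(a), _mm256_loadu_ps(b));
 *   }
 *
 * On GCC/Clang builds where LIBXSMM_MAX_STATIC_TARGET_ARCH exceeds the baseline, this expands
 * to a target attribute such as target("avx2,fma") (see LIBXSMM_ATTRIBUTE_TARGET_1006 above);
 * on compilers that target the ISA anyway it expands to nothing. */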
/** Include basic x86 intrinsics such as __rdtsc. */
#if defined(LIBXSMM_INTRINSICS_INCLUDE) && !defined(LIBXSMM_INTRINSICS_DEBUG)
# if defined(_WIN32)
# include <intrin.h>
# elif defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) || defined(__clang__) || defined(__PGI)
# include <x86intrin.h>
# elif defined(__GNUC__) && (LIBXSMM_VERSION2(4, 4) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))
# include <x86intrin.h>
# endif
# include <xmmintrin.h>
# if defined(__SSE3__)
# include <pmmintrin.h>
# endif
#endif
#if !defined(LIBXSMM_INTRINSICS_NONE)
# if defined(_WIN32)
# include <malloc.h>
# else
# include <mm_malloc.h>
# endif
#endif
/**
* Intrinsic-specific fix-ups
*/
# define LIBXSMM_INTRINSICS_LOADU_SI128(A) _mm_loadu_si128(A)
#if !defined(LIBXSMM_INTEL_COMPILER) && defined(__clang__) && ( \
(LIBXSMM_VERSION2(3, 9) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
|| (LIBXSMM_VERSION2(7, 3) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && \
defined(__APPLE__) && defined(__MACH__)))
/* prototypes with incorrect signature: _mm512_load_ps takes DP*, _mm512_load_pd takes SP* (checked with v3.8.1) */
# define LIBXSMM_INTRINSICS_MM512_LOAD_PS(A) _mm512_loadu_ps((const double*)(A))
# define LIBXSMM_INTRINSICS_MM512_LOAD_PD(A) _mm512_loadu_pd((const float*)(A))
/* Clang misses _mm512_stream_ps/pd/si512 (checked with v3.8.1). */
# define LIBXSMM_INTRINSICS_MM512_STREAM_SI512(A, B) _mm512_store_si512(A, B)
# define LIBXSMM_INTRINSICS_MM512_STREAM_PS(A, B) _mm512_storeu_ps(A, B)
# define LIBXSMM_INTRINSICS_MM512_STREAM_PD(A, B) _mm512_store_pd(A, B)
#else
# define LIBXSMM_INTRINSICS_MM512_LOAD_PS(A) _mm512_loadu_ps((const float*)(A))
# define LIBXSMM_INTRINSICS_MM512_LOAD_PD(A) _mm512_loadu_pd((const double*)(A))
# define LIBXSMM_INTRINSICS_MM512_STREAM_SI512(A, B) _mm512_stream_si512((__m512i*)(A), (B))
# define LIBXSMM_INTRINSICS_MM512_STREAM_PS(A, B) _mm512_stream_ps(A, B)
# define LIBXSMM_INTRINSICS_MM512_STREAM_PD(A, B) _mm512_stream_pd(A, B)
#endif
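/* Note: under the Clang workaround above, the stream macros fall back to regular stores
 * (the non-temporal hint is lost, behavior is otherwise identical), and the load macros
 * deliberately cast to the mismatched pointer type that the broken prototypes expect. */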
#if !defined(LIBXSMM_INTEL_COMPILER) || (defined(__clang__) && ( \
(LIBXSMM_VERSION2(8, 0) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)))) \
|| (defined(__APPLE__) && defined(__MACH__)) || defined(__GNUC__)
# define LIBXSMM_INTRINSICS_MM256_STORE_EPI32(A, B) _mm256_storeu_si256((__m256i*)(A), B)
#else
# define LIBXSMM_INTRINSICS_MM256_STORE_EPI32(A, B) _mm256_storeu_epi32(A, B)
#endif
#if defined(LIBXSMM_INTEL_COMPILER)
# if 1600 <= (LIBXSMM_INTEL_COMPILER)
# define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
_mm512_set_epi16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0)
# else
# define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
_mm512_castps_si512(_mm512_set_epi16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0))
# endif
#else
# define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
_mm512_set_epi32(((E31) << 16) | (E30), ((E29) << 16) | (E28), ((E27) << 16) | (E26), ((E25) << 16) | (E24), \
((E23) << 16) | (E22), ((E21) << 16) | (E20), ((E19) << 16) | (E18), ((E17) << 16) | (E16), \
((E15) << 16) | (E14), ((E13) << 16) | (E12), ((E11) << 16) | (E10), ((E9) << 16) | (E8), \
((E7) << 16) | (E6), ((E5) << 16) | (E4), ((E3) << 16) | (E2), ((E1) << 16) | (E0))
#endif
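/* The fallback above packs two 16-bit lanes into each 32-bit element, i.e., a lane pair
 * (E1, E0) becomes ((E1) << 16) | (E0). A minimal check of the packing arithmetic
 * (hypothetical values):
 *   E1 = 0x1234, E0 = 0x5678  =>  (0x1234 << 16) | 0x5678 == 0x12345678.
 * The arguments are expected to be 16-bit values; wider values would bleed into the
 * neighboring lane since the low half is not masked. */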
#if defined(LIBXSMM_INTEL_COMPILER) \
|| (defined(__GNUC__) && LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
|| (defined(__clang__) && (!defined(__APPLE__) || !defined(__MACH__)) \
&& LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# define LIBXSMM_INTRINSICS_MM512_MASK_I32GATHER_EPI32(A, B, C, D, E) _mm512_mask_i32gather_epi32(A, B, C, D, E)
# define LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(A, B) _mm512_extracti64x4_epi64(A, B)
# define LIBXSMM_INTRINSICS_MM512_ABS_PS(A) _mm512_abs_ps(A)
# define LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() _mm512_undefined_epi32()
# define LIBXSMM_INTRINSICS_MM512_UNDEFINED() _mm512_undefined()
# define LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256() _mm256_undefined_si256()
# define LIBXSMM_INTRINSICS_MM_UNDEFINED_SI128() _mm_undefined_si128()
# define LIBXSMM_INTRINSICS_MM_UNDEFINED_PD() _mm_undefined_pd()
#else
# define LIBXSMM_INTRINSICS_MM512_MASK_I32GATHER_EPI32(A, B, C, D, E) _mm512_castps_si512(_mm512_mask_i32gather_ps( \
_mm512_castsi512_ps(A), B, C, (const float*)(D), E))
# define LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(A, B) _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castsi512_pd(A), B))
# define LIBXSMM_INTRINSICS_MM512_ABS_PS(A) _mm512_castsi512_ps(_mm512_and_epi32( \
_mm512_castps_si512(A), _mm512_set1_epi32(0x7FFFFFFF)))
# define LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() _mm512_set1_epi32(0)
# define LIBXSMM_INTRINSICS_MM512_UNDEFINED() _mm512_set1_ps(0)
# define LIBXSMM_INTRINSICS_MM256_UNDEFINED_SI256() _mm256_set1_epi32(0)
# define LIBXSMM_INTRINSICS_MM_UNDEFINED_SI128() _mm_set1_epi32(0)
# define LIBXSMM_INTRINSICS_MM_UNDEFINED_PD() _mm_set1_pd(0)
#endif
#if (defined(LIBXSMM_INTEL_COMPILER) && (1800 <= (LIBXSMM_INTEL_COMPILER))) \
|| (!defined(LIBXSMM_INTEL_COMPILER) && defined(__GNUC__) \
&& LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
|| ((!defined(__APPLE__) || !defined(__MACH__)) && defined(__clang__) \
&& LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
LIBXSMM_CONCATENATE(_store_mask, NBITS)((LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR), SRC)
# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) \
LIBXSMM_CONCATENATE(_load_mask, NBITS)((/*const*/ LIBXSMM_CONCATENATE(__mmask, NBITS)*)(SRC_PTR))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) LIBXSMM_CONCATENATE(_cvtu32_mask, NBITS)((unsigned int)(A))
#elif defined(LIBXSMM_INTEL_COMPILER)
# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
(*(LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR) = (LIBXSMM_CONCATENATE(__mmask, NBITS))(SRC))
# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) \
((LIBXSMM_CONCATENATE(__mmask, NBITS))_mm512_mask2int(*(const __mmask16*)(SRC_PTR)))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) LIBXSMM_CONCATENATE(LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_, NBITS)(A)
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_32(A) ((__mmask32)(A))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_16(A) _mm512_int2mask((int)(A))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_8(A) ((__mmask8)(A))
#else
# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
(*(LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR) = (LIBXSMM_CONCATENATE(__mmask, NBITS))(SRC))
# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) (*(const LIBXSMM_CONCATENATE(__mmask, NBITS)*)(SRC_PTR))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) ((LIBXSMM_CONCATENATE(__mmask, NBITS))(A))
#endif
#define LIBXSMM_INTRINSICS_MM512_STORE_MASK64(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 64)
#define LIBXSMM_INTRINSICS_MM512_STORE_MASK32(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 32)
#define LIBXSMM_INTRINSICS_MM512_STORE_MASK16(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 16)
#define LIBXSMM_INTRINSICS_MM512_STORE_MASK8(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 8)
#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK64(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 64)
#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK32(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 32)
#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK16(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 16)
#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK8(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 8)
#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK32(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 32)
#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 16)
#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 8)
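/* Usage sketch (hypothetical variable names): the NBITS-suffixed wrappers above move
 * AVX-512 mask registers to/from memory portably across compiler generations, e.g.,
 *
 *   unsigned short bits;                                             // storage for a 16-bit mask
 *   __mmask16 k = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(0x00FF);    // lower 8 lanes enabled
 *   LIBXSMM_INTRINSICS_MM512_STORE_MASK16(&bits, k);                 // bits == 0x00FF afterwards
 *   k = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16(&bits);                 // round-trip
 *
 * On newer compilers these map to _store_mask16/_load_mask16/_cvtu32_mask16; older
 * toolchains fall back to plain integer moves as defined above. */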
/**
* Pseudo intrinsics for portability
*/
LIBXSMM_API_INLINE int LIBXSMM_INTRINSICS_BITSCANFWD32_SW(unsigned int n) {
unsigned int i, r = 0; if (0 != n) for (i = 1; 0 == (n & i); i <<= 1) { ++r; } return r;
}
LIBXSMM_API_INLINE int LIBXSMM_INTRINSICS_BITSCANFWD64_SW(unsigned long long n) {
  unsigned long long i; unsigned int r = 0; /* 64-bit scan variable: a 32-bit counter would wrap to zero and loop forever for n >= 2^32 */
  if (0 != n) for (i = 1; 0 == (n & i); i <<= 1) { ++r; } return r;
}
/** Binary logarithm (based on Stack Overflow's NBITSx macro). */
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW02(N) (0 != ((N) & 0x2/*0b10*/) ? 1 : 0)
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW04(N) (0 != ((N) & 0xC/*0b1100*/) ? (2 | LIBXSMM_INTRINSICS_BITSCANBWD_SW02((N) >> 2)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW02(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW08(N) (0 != ((N) & 0xF0/*0b11110000*/) ? (4 | LIBXSMM_INTRINSICS_BITSCANBWD_SW04((N) >> 4)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW04(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW16(N) (0 != ((N) & 0xFF00) ? (8 | LIBXSMM_INTRINSICS_BITSCANBWD_SW08((N) >> 8)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW08(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW32(N) (0 != ((N) & 0xFFFF0000) ? (16 | LIBXSMM_INTRINSICS_BITSCANBWD_SW16((N) >> 16)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW16(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW64(N) (0 != ((N) & 0xFFFFFFFF00000000) ? (32 | LIBXSMM_INTRINSICS_BITSCANBWD_SW32((N) >> 32)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW32(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD32_SW(N) LIBXSMM_INTRINSICS_BITSCANBWD_SW32((unsigned int)(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD64_SW(N) LIBXSMM_INTRINSICS_BITSCANBWD_SW64((unsigned long long)(N))
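/* Worked example: LIBXSMM_INTRINSICS_BITSCANBWD32_SW(40) == 5, since 40 == 0b101000 and the
 * highest set bit is bit 5; LIBXSMM_INTRINSICS_BITSCANBWD32_SW(1) == 0, and the macros
 * evaluate to 0 for N == 0 by construction. */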
#if defined(_WIN32) && !defined(LIBXSMM_INTRINSICS_NONE)
LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANFWD32(unsigned int n) {
unsigned long r = 0; _BitScanForward(&r, n); return (0 != n) * r;
}
LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANBWD32(unsigned int n) {
unsigned long r = 0; _BitScanReverse(&r, n); return r;
}
# if defined(_WIN64)
LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANFWD64(unsigned long long n) {
unsigned long r = 0; _BitScanForward64(&r, n); return (0 != n) * r;
}
LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANBWD64(unsigned long long n) {
unsigned long r = 0; _BitScanReverse64(&r, n); return r;
}
# else
# define LIBXSMM_INTRINSICS_BITSCANFWD64 LIBXSMM_INTRINSICS_BITSCANFWD64_SW
# define LIBXSMM_INTRINSICS_BITSCANBWD64 LIBXSMM_INTRINSICS_BITSCANBWD64_SW
# endif
#elif defined(__GNUC__) && !defined(LIBXSMM_INTRINSICS_NONE)
# define LIBXSMM_INTRINSICS_BITSCANFWD32(N) ((0 != (N)) * __builtin_ctz(N))
# define LIBXSMM_INTRINSICS_BITSCANFWD64(N) ((0 != (N)) * __builtin_ctzll(N))
# define LIBXSMM_INTRINSICS_BITSCANBWD32(N) ((0 != (N)) * (31 - __builtin_clz(N)))
# define LIBXSMM_INTRINSICS_BITSCANBWD64(N) ((0 != (N)) * (63 - __builtin_clzll(N)))
#else /* fallback implementation */
# define LIBXSMM_INTRINSICS_BITSCANFWD32 LIBXSMM_INTRINSICS_BITSCANFWD32_SW
# define LIBXSMM_INTRINSICS_BITSCANFWD64 LIBXSMM_INTRINSICS_BITSCANFWD64_SW
# define LIBXSMM_INTRINSICS_BITSCANBWD32 LIBXSMM_INTRINSICS_BITSCANBWD32_SW
# define LIBXSMM_INTRINSICS_BITSCANBWD64 LIBXSMM_INTRINSICS_BITSCANBWD64_SW
#endif
/** LIBXSMM_NBITS determines the minimum number of bits needed to represent N. */
#define LIBXSMM_NBITS(N) (LIBXSMM_INTRINSICS_BITSCANBWD64(N) + LIBXSMM_MIN(1, N))
#define LIBXSMM_ISQRT2(N) ((unsigned int)((1ULL << (LIBXSMM_NBITS(N) >> 1)) /*+ LIBXSMM_MIN(1, N)*/))
/** LIBXSMM_ILOG2 definition matches ceil(log2(N)). */
LIBXSMM_API_INLINE unsigned int LIBXSMM_ILOG2(unsigned long long n) {
unsigned int result = 0; if (1 < n) {
const unsigned int m = LIBXSMM_INTRINSICS_BITSCANBWD64(n);
result = m + ((unsigned int)LIBXSMM_INTRINSICS_BITSCANBWD64(n - 1) == m);
} return result;
}
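/* Worked examples (following the definitions above): LIBXSMM_NBITS(5) == 3 (binary 101) and
 * LIBXSMM_NBITS(0) == 0; LIBXSMM_ILOG2 matches ceil(log2(n)) for n > 1, e.g.,
 * LIBXSMM_ILOG2(8) == 3 while LIBXSMM_ILOG2(9) == 4 (9 is not a power of two, so the n-1
 * scan hits the same bit and triggers the +1 correction); LIBXSMM_ISQRT2 yields a
 * power-of-two approximation of sqrt(N), e.g., LIBXSMM_ISQRT2(40) == 8. */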
/**
* Target attribution
*/
#if !defined(LIBXSMM_INTRINSICS_KNC) && !defined(LIBXSMM_INTRINSICS_NONE) && defined(__MIC__)
# define LIBXSMM_INTRINSICS_KNC
#endif
/** LIBXSMM_INTRINSICS_X86 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_X86) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_GENERIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_X86
#endif
/** LIBXSMM_INTRINSICS_SSE3 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_SSE3) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_SSE3 <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_SSE3 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_SSE3
#endif
/** LIBXSMM_INTRINSICS_SSE42 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_SSE42) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_SSE42 <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_SSE42 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_SSE42
#endif
/** LIBXSMM_INTRINSICS_AVX is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX
#endif
/** LIBXSMM_INTRINSICS_AVX2 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX2) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX2
#endif
/** LIBXSMM_INTRINSICS_AVX512 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512
#endif
/** LIBXSMM_INTRINSICS_AVX512_MIC is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_MIC) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_MIC <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_MIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_MIC
#endif
/** LIBXSMM_INTRINSICS_AVX512_KNM is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_KNM) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_KNM <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_KNM <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_KNM
#endif
/** LIBXSMM_INTRINSICS_AVX512_CORE is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CORE) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_CORE
#endif
/** LIBXSMM_INTRINSICS_AVX512_CLX is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CLX) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_CLX <= LIBXSMM_STATIC_TARGET_ARCH || \
(!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_CLX <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_CLX
#endif
/** LIBXSMM_INTRINSICS_AVX512_CPX is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CPX) && !defined(LIBXSMM_INTRINSICS_NONE) && defined(LIBXSMM_X86_AVX512_CPX) && \
!defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_X86_AVX512_CPX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_INTRINSICS_AVX512_CPX
#endif
/** 2048-bit state for xoshiro128+ RNG (state/symbols needed even if AVX-512 is not used) */
#define LIBXSMM_INTRINSICS_MM512_RNG_STATE(INDEX) (*(__m512i*)LIBXSMM_CONCATENATE(libxsmm_intrinsics_mm512_rng_state, INDEX))
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state0[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state1[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state2[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state3[16]);
/**
* Pseudo intrinsics (AVX-2)
*/
#if defined(LIBXSMM_INTRINSICS_AVX2) /*__AVX2__*/
# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0
LIBXSMM_PRAGMA_OPTIMIZE_OFF /* avoid ICE in case of symbols (-g) */
# endif
/** Generate random numbers in the interval [0, 1); thread-safe, the state needs to be managed by the user.
 * This is based on xoshiro128+ 1.0, e.g., http://prng.di.unimi.it/xoshiro128plus.c */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) __m256i LIBXSMM_INTRINSICS_MM256_RNG_XOSHIRO128P_EXTSTATE_EPI32(unsigned int* stateptr) {
__m256i state_0 = _mm256_loadu_si256( (const __m256i*)stateptr );
__m256i state_1 = _mm256_loadu_si256( (const __m256i*)(stateptr+16) );
__m256i state_2 = _mm256_loadu_si256( (const __m256i*)(stateptr+32) );
__m256i state_3 = _mm256_loadu_si256( (const __m256i*)(stateptr+48) );
const __m256i result = _mm256_add_epi32(state_0, state_3);
const __m256i s = _mm256_slli_epi32(state_1, 9);
__m256i t;
state_2 = _mm256_xor_si256(state_2, state_0);
state_3 = _mm256_xor_si256(state_3, state_1);
state_1 = _mm256_xor_si256(state_1, state_2);
state_0 = _mm256_xor_si256(state_0, state_3);
state_2 = _mm256_xor_si256(state_2, s);
_mm256_storeu_si256( (__m256i*)stateptr , state_0 );
_mm256_storeu_si256( (__m256i*)(stateptr+16), state_1 );
_mm256_storeu_si256( (__m256i*)(stateptr+32), state_2 );
t = _mm256_slli_epi32(state_3, 11);
state_3 = _mm256_or_si256(t, _mm256_srli_epi32(state_3, 32 - 11));
_mm256_storeu_si256( (__m256i*)(stateptr+48), state_3 );
return result;
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX2) __m256 LIBXSMM_INTRINSICS_MM256_RNG_EXTSTATE_PS(unsigned int* stateptr) {
const __m256i rng_mantissa = _mm256_srli_epi32( LIBXSMM_INTRINSICS_MM256_RNG_XOSHIRO128P_EXTSTATE_EPI32(stateptr), 9 );
const __m256 one = _mm256_set1_ps(1.0f);
return _mm256_sub_ps(_mm256_castsi256_ps(_mm256_or_si256(_mm256_set1_epi32(0x3f800000), rng_mantissa)), one);
}
# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0
LIBXSMM_PRAGMA_OPTIMIZE_ON
# endif
#endif /*__AVX2__*/
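/* Usage sketch for the AVX2 external-state RNG above (hypothetical seeding; a production
 * seed should come from a scalar xoshiro128+ stream or another entropy source):
 *
 *   unsigned int state[64]; unsigned int i;            // 4 x 16 lanes of xoshiro128+ state
 *   for (i = 0; i < 64; ++i) state[i] = 0x9E3779B9u * (i + 1);  // must not be all-zero
 *   { __m256 r = LIBXSMM_INTRINSICS_MM256_RNG_EXTSTATE_PS(state); } // 8 floats in [0, 1)
 *
 * Each lane runs an independent xoshiro128+ stream, so per-thread state arrays keep the
 * generator thread-safe as noted above (the AVX2 path reads the first 8 lanes of each
 * 16-lane block, matching the AVX-512 state layout). */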
/**
* Pseudo intrinsics (AVX-512)
*/
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
# define LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( A, B ) _mm512_cvtepi32_epi16(_mm512_cvt_roundps_epi32( \
_mm512_mul_ps(LIBXSMM_INTRINSICS_MM512_LOAD_PS(A), B), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
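/* QUANTIZE_NEAR_PS_EPI16 scales 16 fp32 values at A by B, rounds to the nearest integer,
 * and narrows to 16-bit lanes; e.g. (hypothetical values), A pointing at {0.49f, ...} with
 * B == _mm512_set1_ps(100.f) yields 49 in the first lane. */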
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(__m512 a) {
const __m512i vnaninf = _mm512_set1_epi32(0x7f800000), vrneadd = _mm512_set1_epi32(0x00007fff);
const __m512i vfixup = _mm512_set1_epi32(0x00000001), vfixupmask = _mm512_set1_epi32(0x00010000);
const __m512i mm512_roundbf16rne_a_ = _mm512_castps_si512(a);
const __mmask16 mm512_roundbf16rne_mask1_ = _mm512_cmp_epi32_mask(_mm512_and_epi32(mm512_roundbf16rne_a_, vnaninf), vnaninf, _MM_CMPINT_NE);
const __mmask16 mm512_roundbf16rne_mask2_ = _mm512_cmp_epi32_mask(_mm512_and_epi32(mm512_roundbf16rne_a_, vfixupmask), vfixupmask, _MM_CMPINT_EQ);
return _mm512_mask_add_epi32(mm512_roundbf16rne_a_, mm512_roundbf16rne_mask1_, mm512_roundbf16rne_a_, _mm512_mask_add_epi32(vrneadd, mm512_roundbf16rne_mask2_, vrneadd, vfixup));
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m256i LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(__m512 a) {
return _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(a), 16));
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_CVT2_FP32_BF16(__m512 a, __m512 b) {
const __m256i aa = _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(b), 16));
const __m256i bb = _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(a), 16));
return _mm512_inserti64x4(_mm512_inserti64x4(_mm512_setzero_si512(), aa, 0), bb, 1);
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(__m256i a) {
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(a),16));
}
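/* Round-trip sketch for the BF16 helpers above (hypothetical values): truncating fp32 to
 * bfloat16 keeps the sign, the 8 exponent bits, and the top 7 mantissa bits, and
 * LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16 applies round-to-nearest-even before the truncation
 * (NaN/Inf lanes are left untouched via the vnaninf mask):
 *
 *   __m512  a = _mm512_set1_ps(1.00390625f);                 // 0x3F808000: an exact tie
 *   __m256i b = LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(a);   // rounds to even: 0x3F80
 *   __m512  c = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(b);       // back to fp32: 1.0f
 */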
/** SVML-intrinsics */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(__m512 x) {
const __m512 c0 = _mm512_set1_ps(2027025.0f);
const __m512 c1 = _mm512_set1_ps(270270.0f);
const __m512 c2 = _mm512_set1_ps(6930.0f);
const __m512 c3 = _mm512_set1_ps(36.0f);
const __m512 c1_d = _mm512_set1_ps(945945.0f);
const __m512 c2_d = _mm512_set1_ps(51975.0f);
const __m512 c3_d = _mm512_set1_ps(630.0f);
const __m512 hi_bound = _mm512_set1_ps(4.97f);
const __m512 lo_bound = _mm512_set1_ps(-4.97f);
const __m512 ones = _mm512_set1_ps(1.0f);
const __m512 neg_ones = _mm512_set1_ps(-1.0f);
const __m512 x2 = _mm512_mul_ps( x, x );
const __m512 t1_nom = _mm512_fmadd_ps( c3, x2, c2 );
const __m512 t2_nom = _mm512_fmadd_ps( t1_nom, x2, c1 );
const __m512 t3_nom = _mm512_fmadd_ps( t2_nom, x2, c0 );
const __m512 nom = _mm512_mul_ps( t3_nom, x );
const __m512 t1_denom = _mm512_add_ps( x2, c3_d );
const __m512 t2_denom = _mm512_fmadd_ps( t1_denom, x2, c2_d );
const __m512 t3_denom = _mm512_fmadd_ps( t2_denom, x2, c1_d );
const __m512 denom = _mm512_fmadd_ps( t3_denom, x2, c0 );
const __m512 denom_rcp = _mm512_rcp14_ps( denom );
const __mmask16 mask_hi = _mm512_cmp_ps_mask( x, hi_bound, _CMP_GT_OQ);
const __mmask16 mask_lo = _mm512_cmp_ps_mask( x, lo_bound, _CMP_LT_OQ);
__m512 result = _mm512_mul_ps( nom, denom_rcp );
result = _mm512_mask_blend_ps(mask_hi, result, ones);
result = _mm512_mask_blend_ps(mask_lo, result, neg_ones);
return result;
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_32(__m512 x) {
const __m512 c1 = _mm512_set1_ps((float)(1.0/27.0));
const __m512 c2 = _mm512_set1_ps((float)(1.0/3));
const __m512 hi_bound = _mm512_set1_ps(3.2f);
const __m512 lo_bound = _mm512_set1_ps(-3.2f);
const __m512 ones = _mm512_set1_ps(1.0f);
const __m512 neg_ones = _mm512_set1_ps(-1.0f);
const __m512 x2 = _mm512_mul_ps( x, x );
const __m512 t1_nom = _mm512_fmadd_ps( x2, c1, ones);
const __m512 nom = _mm512_mul_ps( t1_nom, x );
const __m512 denom = _mm512_fmadd_ps( x2, c2, ones);
const __m512 denom_rcp = _mm512_rcp14_ps( denom );
const __mmask16 mask_hi = _mm512_cmp_ps_mask( x, hi_bound, _CMP_GT_OQ);
const __mmask16 mask_lo = _mm512_cmp_ps_mask( x, lo_bound, _CMP_LT_OQ);
__m512 result = _mm512_mul_ps(nom, denom_rcp);
result = _mm512_mask_blend_ps(mask_hi, result, ones);
result = _mm512_mask_blend_ps(mask_lo, result, neg_ones);
return result;
}
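/* The two kernels above are rational approximations of tanh(x): RATIONAL_78 evaluates a
 * degree-7/8 quotient whose coefficients (2027025 = 3*5*7*9*11*13*15, etc.) stem from the
 * continued-fraction expansion of tanh, while RATIONAL_32 uses the compact form
 * x*(1 + x^2/27) / (1 + x^2/3); both clamp the result to +/-1 beyond the bound where the
 * quotient drifts from tanh (4.97 resp. 3.2). */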
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_EXP2(__m512 _x) {
const __m512 twice_log2_e = _mm512_set1_ps((float)(1.442695*2));
const __m512 half = _mm512_set1_ps(0.5f);
const __m512 c2 = _mm512_set1_ps(0.240226507f);
const __m512 c1 = _mm512_set1_ps(0.452920674f);
const __m512 c0 = _mm512_set1_ps(0.713483036f);
const __m512 ones = _mm512_set1_ps(1.0f);
const __m512 minus_twos = _mm512_set1_ps(-2.0f);
const __m512 x = _mm512_fmadd_ps(_x, twice_log2_e, half);
#if 1
const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION));
#else
const __m512 y = _mm512_reduce_ps(x, 1);
#endif
const __m512 t1 = _mm512_fmadd_ps( y, c2, c1);
const __m512 two_to_y = _mm512_fmadd_ps( y, t1, c0);
const __m512 exp = _mm512_scalef_ps( two_to_y, x );
const __m512 denom_rcp = _mm512_rcp14_ps( _mm512_add_ps( exp, ones) );
__m512 result = _mm512_fmadd_ps( denom_rcp, minus_twos, ones);
return result;
}
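/* EXP2/EXP3 compute tanh via the identity tanh(x) = 1 - 2/(exp(2x) + 1): multiplying by
 * twice_log2_e rewrites exp(2x) as a power of two 2^t, the roundscale/sub pair extracts the
 * fractional part y of t, a degree-2 (EXP2) resp. degree-3 (EXP3) polynomial approximates
 * the fractional power of two (the +0.5 bias recenters the fit), and _mm512_scalef_ps
 * re-applies the integer power of two. */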
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_EXP3(__m512 _x) {
const __m512 twice_log2_e = _mm512_set1_ps((float)(1.442695*2));
const __m512 half = _mm512_set1_ps(0.5f);
const __m512 c3 = _mm512_set1_ps(0.05550410866f);
const __m512 c2 = _mm512_set1_ps(0.15697034396f);
const __m512 c1 = _mm512_set1_ps(0.49454875509f);
const __m512 c0 = _mm512_set1_ps(0.70654502287f);
const __m512 ones = _mm512_set1_ps(1.0f);
const __m512 minus_twos = _mm512_set1_ps(-2.0f);
const __m512 x = _mm512_fmadd_ps(_x, twice_log2_e, half);
#if 1
const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION));
#else
const __m512 y = _mm512_reduce_ps(x, 1);
#endif
const __m512 t1 = _mm512_fmadd_ps( y, c3, c2);
const __m512 t2 = _mm512_fmadd_ps( y, t1, c1);
const __m512 two_to_y = _mm512_fmadd_ps( y, t2, c0);
const __m512 exp = _mm512_scalef_ps( two_to_y, x );
const __m512 denom_rcp = _mm512_rcp14_ps( _mm512_add_ps( exp, ones) );
__m512 result = _mm512_fmadd_ps( denom_rcp, minus_twos, ones);
return result;
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(__m512 x) {
__m512 result, func_p0, func_p1, func_p2;
const __m512i sign_mask = _mm512_set1_epi32( 0x80000000 );
const __m512i sign_filter = _mm512_set1_epi32( 0x7FFFFFFF );
const __m512i lut_low = _mm512_set1_epi32( 246 );
const __m512i lut_high = _mm512_set1_epi32( 261 );
const __m512 tanh_p0_2_reg = _mm512_set_ps( 0.40555000f, 0.11892800f, -0.00972979f, -0.02740300f, -0.0169851f, -0.00776152f, -0.00305889f,
-0.00116259f, -0.00041726f, -8.53233e-6f, 1.0000000f, 0.99999800f, 0.99975400f, 0.99268200f,
0.93645300f, 0.73833900f);
const __m512 tanh_p1_2_reg = _mm512_set_ps( 0.495602f, 0.88152f, 1.125700000f, 1.17021000f, 1.1289000000f, 1.07929000f, 1.0432300f, 1.023010f,
1.011620f, 1.00164f, 1.56828e-14f, 4.49924e-7f, 0.0000646924f, 0.00260405f, 0.0311608f, 0.168736f);
const __m512 tanh_p2_2_reg = _mm512_set_ps(-0.108238f, -0.2384280f, -0.354418000f, -0.38240300f, -0.34135700f, -0.274509000f, -0.20524900f, -0.1511960f,
-0.107635f, -0.0466868f, -3.60822e-16f, -2.05971e-8f, -4.24538e-6f, -0.000231709f, -0.00386434f, -0.0277702f);
const __m512i signs = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask);
const __m512i abs_arg = _mm512_and_epi32(_mm512_castps_si512(x), sign_filter);
__m512i indices = _mm512_srli_epi32(abs_arg, 22);
indices = _mm512_max_epi32(indices, lut_low);
indices = _mm512_min_epi32(indices, lut_high);
func_p0 = _mm512_permutexvar_ps(indices, tanh_p0_2_reg);
func_p1 = _mm512_permutexvar_ps(indices, tanh_p1_2_reg);
func_p2 = _mm512_permutexvar_ps(indices, tanh_p2_2_reg);
result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), func_p2, func_p1);
result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p0);
result = _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(result), signs));
return result;
}
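/* MINIMAX2/MINIMAX3 are piecewise-polynomial approximations: shifting the absolute
 * argument's bit pattern right by 22 keeps the exponent plus the top mantissa bit, i.e.,
 * two segments per binade; clamping the index to [246, 261] selects one of 16 table entries,
 * _mm512_permutexvar_ps gathers the per-segment coefficients, and the polynomial is
 * evaluated via Horner's rule before the sign is re-applied with an XOR. */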
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX3(__m512 x) {
__m512 result, func_p0, func_p1, func_p2, func_p3;
const __m512i sign_mask = _mm512_set1_epi32( 0x80000000 );
const __m512i sign_filter = _mm512_set1_epi32( 0x7FFFFFFF );
const __m512i lut_low = _mm512_set1_epi32( 246 );
const __m512i lut_high = _mm512_set1_epi32( 261 );
const __m512 tanh_p0_3_reg = _mm512_setr_ps( 0.466283000f, 0.82850600f, 0.97437500f, 0.99882600f, 0.9999860f, 1.0000000f, -1.50006e-08f, -7.98169e-06f,
-4.53753e-05f, -0.00023755f, -0.00125285f, -0.00572314f, -0.0227717f, -0.0629089f, -0.084234300f, 0.071199800f);
const __m512 tanh_p1_3_reg = _mm512_setr_ps( 0.500617f, 0.124369f, 0.0137214f, 0.000464124f, 4.02465e-06f, 0.00000f, 1.00001f, 1.00028f, 1.00112f, 1.00414f,
1.015570f, 1.050950f, 1.1478500f, 1.310130000f, 1.378950000f, 1.07407f);
const __m512 tanh_p2_3_reg = _mm512_setr_ps(-0.16133200f, -0.0305526f, -0.00245909f, -6.12647e-05f, -3.76127e-07f, 0.000000f, -0.000245872f, -0.00341151f,
-0.00971505f, -0.0256817f, -0.06869110f, -0.162433000f, -0.346828000f, -0.566516f, -0.640214000f, -0.44011900f);
const __m512 tanh_p3_3_reg = _mm512_setr_ps( 0.0177393f, 0.00253432f, 0.000147303f, 2.69963e-06f, 1.16764e-08f, 0.0000000f, -0.330125f, -0.3176210f,
-0.3017760f, -0.27358000f, -0.219375000f, -0.136197000f, -0.01868680f, 0.0808901f, 0.107095f, 0.0631459f);
const __m512i signs = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask);
const __m512i abs_arg = _mm512_and_epi32(_mm512_castps_si512(x), sign_filter);
__m512i indices = _mm512_srli_epi32(abs_arg, 22);
indices = _mm512_max_epi32(indices, lut_low);
indices = _mm512_min_epi32(indices, lut_high);
func_p0 = _mm512_permutexvar_ps(indices, tanh_p0_3_reg);
func_p1 = _mm512_permutexvar_ps(indices, tanh_p1_3_reg);
func_p2 = _mm512_permutexvar_ps(indices, tanh_p2_3_reg);
func_p3 = _mm512_permutexvar_ps(indices, tanh_p3_3_reg);
result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), func_p3, func_p2);
result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p1);
result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p0);
result = _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(result), signs));
return result;
}
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512DQ__ needed*/
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) __m512 LIBXSMM_INTRINSICS_MM512_GELU_FWD_PS_MINIMAX3(__m512 x) {
const __m512 thres = _mm512_castsi512_ps(_mm512_set1_epi32(0x40879fff));
const __m512 absmask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffff));
const __m512 scale = _mm512_castsi512_ps(_mm512_set1_epi32(0x406a0ea1));
const __m512 shifter = _mm512_castsi512_ps(_mm512_set1_epi32(0x4b400000));
const __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000));
const __m512 _c2 = _mm512_castsi512_ps(_mm512_setr_epi32(0xbd877b85u, 0xbd7d9780u, 0xbd4cb70eu, 0xbd08a1e9u, 0xbc808857u, 0xb9476fd2u, 0x3c36f765u, 0x3c924160u,
0x3ca7b1fcu, 0x3ca5732cu, 0x3c95af63u, 0x3c8079f7u, 0x3c55fa4fu, 0x3c2fa86bu, 0x3c0fbb00u, 0x3bec178cu));
const __m512 _c1 = _mm512_castsi512_ps(_mm512_setr_epi32(0xb7c7fb58u, 0xbacb9740u, 0xbc3e4b3au, 0xbd0d292au, 0xbd8bc5d0u, 0xbdd9978fu, 0xbe0f92d3u, 0xbe27b66du,
0xbe328ce7u, 0xbe3125bfu, 0xbe26dc9du, 0xbe17a056u, 0xbe06bdebu, 0xbdecc593u, 0xbdcf57aau, 0xbdb5ea3au));
const __m512 _c0 = _mm512_castsi512_ps(_mm512_setr_epi32(0x3ecc4231u, 0x3ecc541cu, 0x3ecd6c48u, 0x3ed174c3u, 0x3ed9bd5du, 0x3ee5acd5u, 0x3ef2aeddu, 0x3efd5384u,
0x3f016724u, 0x3f00f778u, 0x3efb389eu, 0x3ef0464du, 0x3ee3014fu, 0x3ed50a78u, 0x3ec779dbu, 0x3ebae363u));
__m512 result;
__m512 xr = _mm512_range_round_ps(x, thres, 2, _MM_FROUND_NO_EXC);
__m512 xa = _mm512_and_ps(xr, absmask);
__m512 index = _mm512_fmadd_ps(xa, scale, shifter);
__m512 c2 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c2);
__m512 c1 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c1);
__m512 c0 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c0);
__m512 poly = _mm512_fmadd_ps(c2, xa, c1);
poly = _mm512_fmadd_ps(poly, xa, c0);
result = _mm512_mul_ps(x, _mm512_fmadd_ps(poly, xr, half));
return result;
}
#endif /*defined(LIBXSMM_INTRINSICS_AVX512_CORE)*/
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512DQ__ needed*/
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE) __m512 LIBXSMM_INTRINSICS_MM512_GELU_BWD_PS_MINIMAX3(__m512 x) {
const __m512 thres = _mm512_castsi512_ps(_mm512_set1_epi32(0x408f5fff));
const __m512 absmask = _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffff));
const __m512 scale = _mm512_castsi512_ps(_mm512_set1_epi32(0x405d67c9));
const __m512 shifter = _mm512_castsi512_ps(_mm512_set1_epi32(0x4b400000));
const __m512 half = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f000000));
const __m512 _c2 = _mm512_castsi512_ps(_mm512_setr_epi32(0xbe87047bu, 0xbe6eb875u, 0xbe2210c1u, 0xbd81727fu, 0x3cb9625cu, 0x3da2cbe8u, 0x3dd1d4d1u, 0x3dca0bd0u,
0x3da47dd0u, 0x3d6f1bd3u, 0x3d216381u, 0x3cd2618cu, 0x3c89f6e6u, 0x3c3ca672u, 0x3c08ed08u, 0x3bd26a14u));
const __m512 _c1 = _mm512_castsi512_ps(_mm512_setr_epi32(0xb930e738u, 0xbc4b28bau, 0xbda4212fu, 0xbe5feb0eu, 0xbec8b0e5u, 0xbf09e61bu, 0xbf1c403fu, 0xbf185954u,
0xbf03e1eeu, 0xbed08a61u, 0xbe9b4508u, 0xbe61788bu, 0xbe257770u, 0xbdfc542au, 0xbdca014eu, 0xbda8d7e9u));
const __m512 _c0 = _mm512_castsi512_ps(_mm512_setr_epi32(0x3f4c4245u, 0x3f4c927bu, 0x3f5085f8u, 0x3f5d7bdau, 0x3f73ea12u, 0x3f86142fu, 0x3f8d3df4u, 0x3f8b4b0fu,
0x3f8022c8u, 0x3f5e5423u, 0x3f39ceb5u, 0x3f199bedu, 0x3f00bee0u, 0x3ede1737u, 0x3ec59b86u, 0x3eb4454cu));
__m512 result;
__m512 xr = _mm512_range_round_ps(x, thres, 2, _MM_FROUND_NO_EXC);
__m512 xa = _mm512_and_ps(xr, absmask);
__m512 index = _mm512_fmadd_ps(xa, scale, shifter);
__m512 c2 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c2);
__m512 c1 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c1);
__m512 c0 = _mm512_permutexvar_ps(_mm512_castps_si512(index), _c0);
__m512 poly = _mm512_fmadd_ps(c2, xa, c1);
poly = _mm512_fmadd_ps(poly, xa, c0);
result = _mm512_fmadd_ps(poly, xr, half);
return result;
}
#endif /*defined(LIBXSMM_INTRINSICS_AVX512_CORE)*/
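/* The two GELU kernels above use the same table-lookup scheme but derive the index
 * arithmetically: adding shifter (0x4b400000, i.e., 1.5*2^23) to xa*scale pins the float
 * mantissa so its low bits equal the integer part of xa*scale, which permutexvar then
 * consumes directly (only the low 4 index bits are used); _mm512_range_round_ps clamps
 * the argument to +/-thres beforehand. */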
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_GELU_FWD(__m512 x) {
const __m512 c1 = _mm512_set1_ps( (float)0.79788);
const __m512 c2 = _mm512_set1_ps( (float)0.03568);
const __m512 c_half = _mm512_set1_ps( (float)0.5);
__m512 x_half = _mm512_mul_ps( x, c_half );
__m512 x_sq = _mm512_mul_ps( x, x );
__m512 poly_x1 = _mm512_mul_ps(x, _mm512_fmadd_ps( x_sq, c2, c1));
__m512 tanh_poly_x = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(poly_x1);
__m512 output = _mm512_fmadd_ps(tanh_poly_x, x_half, x_half);
return output;
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_GELU_BWD(__m512 x) {
const __m512 c1 = _mm512_set1_ps( (float)0.79788);
const __m512 c2 = _mm512_set1_ps( (float)0.03568);
const __m512 c3 = _mm512_set1_ps( (float)0.05352);
const __m512 c4 = _mm512_set1_ps( (float)0.39894);
const __m512 c_half = _mm512_set1_ps( (float)0.5);
const __m512 c_ones = _mm512_set1_ps( (float)1.0);
const __m512 c_minus_1 = _mm512_set1_ps( (float)-1.0);
__m512 x_sq = _mm512_mul_ps( x, x );
__m512 poly_x1 = _mm512_mul_ps(x, _mm512_fmadd_ps( x_sq, c2, c1));
__m512 poly_x2 = _mm512_mul_ps(x, _mm512_fmadd_ps( x_sq, c3, c4));
__m512 tanh_poly_x = LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2(poly_x1);
__m512 out1 = _mm512_add_ps(c_ones, tanh_poly_x);
__m512 out2 = _mm512_add_ps(c_half, poly_x2);
__m512 out3 = _mm512_fmsub_ps(poly_x2, tanh_poly_x, out2);
__m512 out4 = _mm512_mul_ps(c_minus_1, out3);
__m512 output = _mm512_mul_ps(out1, out4);
return output;
}
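/* TANH_PS_GELU_FWD implements the tanh-based GELU approximation
 *   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715*x^3))),
 * with c1 = sqrt(2/pi) ~= 0.79788 and c2 = 0.044715*sqrt(2/pi) ~= 0.03568 folded into the
 * inner polynomial; TANH_PS_GELU_BWD evaluates the corresponding derivative, where
 * c4 = 0.5*sqrt(2/pi) ~= 0.39894 and c3 = 3*0.044715*0.5*sqrt(2/pi) ~= 0.05352. */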
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_EXP_PS_2DTS(__m512 in) {
const __m512 log2_e = _mm512_set1_ps(1.442695f);
const __m512 half = _mm512_set1_ps(0.5f);
const __m512 c2 = _mm512_set1_ps(0.240226507f);
const __m512 c1 = _mm512_set1_ps(0.452920674f);
const __m512 c0 = _mm512_set1_ps(0.713483036f);
const __m512 x = _mm512_fmadd_ps(in, log2_e, half);
#if 1
const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION));
#else
const __m512 y = _mm512_reduce_ps(x, 1);
#endif
const __m512 t1 = _mm512_fmadd_ps( y, c2, c1);
const __m512 two_to_y = _mm512_fmadd_ps( y, t1, c0);
const __m512 exp = _mm512_scalef_ps( two_to_y, x );
return exp;
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_EXP_PS_3DTS(__m512 in) {
const __m512 log2_e = _mm512_set1_ps(1.442695f);
const __m512 half = _mm512_set1_ps(0.5f);
const __m512 c3 = _mm512_set1_ps(0.05550410866f);
const __m512 c2 = _mm512_set1_ps(0.15697034396f);
const __m512 c1 = _mm512_set1_ps(0.49454875509f);
const __m512 c0 = _mm512_set1_ps(0.70654502287f);
const __m512 x = _mm512_fmadd_ps(in, log2_e, half);
#if 1
const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION));
#else
const __m512 y = _mm512_reduce_ps(x, 1);
#endif
const __m512 t1 = _mm512_fmadd_ps( y, c3, c2);
const __m512 t2 = _mm512_fmadd_ps( y, t1, c1);
const __m512 two_to_y = _mm512_fmadd_ps( y, t2, c0);
const __m512 exp = _mm512_scalef_ps( two_to_y, x );
return exp;
}
# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0
LIBXSMM_PRAGMA_OPTIMIZE_OFF /* avoid ICE in case of symbols (-g) */
# endif
/** Generate random numbers in the interval [0, 1); not thread-safe (uses the global state above).
 * This is based on xoshiro128+ 1.0, e.g., http://prng.di.unimi.it/xoshiro128plus.c */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(void) {
const __m512i result = _mm512_add_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(0), LIBXSMM_INTRINSICS_MM512_RNG_STATE(3));
const __m512i s = _mm512_slli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(1), 9);
__m512i t;
LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(2), LIBXSMM_INTRINSICS_MM512_RNG_STATE(0));
LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), LIBXSMM_INTRINSICS_MM512_RNG_STATE(1));
LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(1), LIBXSMM_INTRINSICS_MM512_RNG_STATE(2));
LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(0), LIBXSMM_INTRINSICS_MM512_RNG_STATE(3));
LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(2), s);
t = _mm512_slli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), 11);
LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_or_epi32(t, _mm512_srli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), 32 - 11));
return result;
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_RNG_PS(void) {
const __m512i rng_mantissa = _mm512_srli_epi32( LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(), 9 );
const __m512 one = _mm512_set1_ps(1.0f);
return _mm512_sub_ps(_mm512_castsi512_ps(_mm512_or_epi32(_mm512_set1_epi32(0x3f800000), rng_mantissa)), one);
}
/** Generate random numbers in the interval [0, 1); thread-safe, the state needs to be managed by the user.
 * This is based on xoshiro128+ 1.0, e.g., http://prng.di.unimi.it/xoshiro128plus.c */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EXTSTATE_EPI32(unsigned int* stateptr) {
__m512i state_0 = _mm512_loadu_si512( stateptr );
__m512i state_1 = _mm512_loadu_si512( stateptr+16 );
__m512i state_2 = _mm512_loadu_si512( stateptr+32 );
__m512i state_3 = _mm512_loadu_si512( stateptr+48 );
const __m512i result = _mm512_add_epi32(state_0, state_3);
const __m512i s = _mm512_slli_epi32(state_1, 9);
__m512i t;
state_2 = _mm512_xor_epi32(state_2, state_0);
state_3 = _mm512_xor_epi32(state_3, state_1);
state_1 = _mm512_xor_epi32(state_1, state_2);
state_0 = _mm512_xor_epi32(state_0, state_3);
state_2 = _mm512_xor_epi32(state_2, s);
_mm512_storeu_si512( stateptr , state_0 );
_mm512_storeu_si512( stateptr+16, state_1 );
_mm512_storeu_si512( stateptr+32, state_2 );