This repository has been archived by the owner on Aug 9, 2022. It is now read-only.
forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathperf_kernel_defs.bzl
66 lines (58 loc) · 2.2 KB
/
perf_kernel_defs.bzl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
load("@fbcode_macros//build_defs:cpp_library.bzl", "cpp_library")
# True when the fbcode "build_mode" config value contains the substring "dbg".
is_dbg_build = "dbg" in native.read_config("fbcode", "build_mode", "")

# True when the fbcode "sanitizer" config value is set to a non-empty string.
is_sanitizer = native.read_config("fbcode", "sanitizer", "") != ""
def define_perf_kernels(
        prefix,
        levels_and_flags,
        compiler_common_flags = [],
        arch_compiler_common_flags = {},
        dependencies = [],
        arch_dependencies = [],
        external_deps = []):
    """Declares the perfkernels `cpp_library` targets.

    One library named `<prefix>perfkernels_<level>` is declared for every
    (level, flags) pair, built from the sources matching `**/*_<level>.cc`.
    A final library named `<prefix>perfkernels` is built from the remaining
    `.cc` files (everything except `*_avx.cc`, `*_avx2.cc`, `*_avx512.cc`)
    and exports the per-level libraries as arch-specific deps.

    Args:
        prefix: String prepended to every target name declared here.
        levels_and_flags: Dict mapping an arch key to an iterable of
            (level, flags) pairs. `level` is the string used in the target
            name and source glob; `flags` is a list appended to
            `compiler_common_flags` for that level's library.
        compiler_common_flags: List passed as `compiler_flags` to every
            library (per-level libraries additionally get their own flags).
        arch_compiler_common_flags: Passed as `arch_compiler_flags` to every
            library.
        dependencies: Passed as `exported_deps` to every library.
        arch_dependencies: Passed as `exported_arch_deps` to the per-level
            libraries and appended to the per-arch kernel target list of the
            aggregate library.
        external_deps: Passed as `exported_external_deps` to the per-level
            libraries only.
    """

    # Vectorization flags are only enabled for non-debug, non-sanitizer builds.
    if is_dbg_build or is_sanitizer:
        vectorize_flags = []
    else:
        vectorize_flags = [
            # "-Rpass=loop-vectorize", # Add vectorization information to output
            "-DENABLE_VECTORIZATION=1",
            "-fveclib=SVML",
        ]

    # clang always gets -Wno-pass-failed, after any vectorization flags.
    compiler_specific_flags = {
        "clang": vectorize_flags + ["-Wno-pass-failed"],
        "gcc": [],
    }

    # Sources with an ISA-level suffix are compiled by the per-level libraries
    # below, so they are kept out of the common library.
    level_specific_patterns = [
        "**/*_avx512.cc",
        "**/*_avx2.cc",
        "**/*_avx.cc",
    ]
    common_srcs = native.glob(["**/*.cc"], exclude = level_specific_patterns)
    cpp_headers = native.glob(["**/*.h"])

    # arch key -> list of ":<target>" labels of the per-level libraries.
    kernel_targets = {}
    for arch, arch_levels in levels_and_flags.items():
        for level, flags in arch_levels:
            level_target = prefix + "perfkernels_" + level
            cpp_library(
                name = level_target,
                srcs = native.glob(["**/*_" + level + ".cc"]),
                headers = cpp_headers,
                compiler_flags = compiler_common_flags + flags,
                arch_compiler_flags = arch_compiler_common_flags,
                compiler_specific_flags = compiler_specific_flags,
                exported_deps = dependencies,
                exported_arch_deps = arch_dependencies,
                exported_external_deps = external_deps,
            )

            # An arch only gets an entry once it has at least one level.
            if arch not in kernel_targets:
                kernel_targets[arch] = []
            kernel_targets[arch].append(":" + level_target)

    # Aggregate library: the common sources plus, per arch, the level-specific
    # libraries declared above.
    cpp_library(
        name = prefix + "perfkernels",
        srcs = common_srcs,
        headers = cpp_headers,
        compiler_flags = compiler_common_flags,
        arch_compiler_flags = arch_compiler_common_flags,
        compiler_specific_flags = compiler_specific_flags,
        link_whole = True,
        exported_arch_deps = kernel_targets.items() + arch_dependencies,
        exported_deps = dependencies,
    )