forked from EricLBuehler/mistral.rs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild.rs
63 lines (57 loc) · 2.3 KB
/
build.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
/// Optional extra flags taken from the `CUDA_NVCC_FLAGS` environment variable.
///
/// `option_env!` is expanded while this build script itself is being compiled,
/// so the value is `None` when the variable was not set at that point. When
/// present, `main` hands it to nvcc right after `--compiler-options`.
#[cfg(feature = "cuda")]
const CUDA_NVCC_FLAGS: Option<&str> = option_env!("CUDA_NVCC_FLAGS");
/// Build-script entry point.
///
/// Does nothing unless the `cuda` feature is enabled. With the feature on, it:
///
/// 1. tells Cargo to rerun the script when `build.rs` or a kernel source changes,
/// 2. compiles `src/cuda/nonzero_bitwise.cu` through `bindgen_cuda` into a
///    library archive placed in `OUT_DIR`,
/// 3. emits the link directives for that archive, for `cudart`, and for the
///    C++ runtime library matching the target triple.
///
/// # Panics
///
/// Panics if the `OUT_DIR` or `TARGET` environment variables cannot be read.
fn main() {
    #[cfg(feature = "cuda")]
    {
        use std::path::PathBuf;

        // Fixed nvcc arguments, applied in exactly this order.
        const NVCC_ARGS: [&str; 10] = [
            "-std=c++17",
            "-O3",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
            "--use_fast_math",
            "--verbose",
        ];

        println!("cargo:rerun-if-changed=build.rs");

        let out_dir: PathBuf = std::env::var("OUT_DIR").unwrap().into();

        // Every kernel source is also a rebuild trigger.
        let kernel_sources = vec!["src/cuda/nonzero_bitwise.cu"];
        for source in &kernel_sources {
            println!("cargo:rerun-if-changed={source}");
        }

        let mut nvcc = bindgen_cuda::Builder::default()
            .kernel_paths(kernel_sources)
            .out_dir(out_dir.clone());
        for flag in NVCC_ARGS.iter().copied() {
            nvcc = nvcc.arg(flag);
        }

        // Forward user-supplied host-compiler flags (for instance `-fPIC -fPIE`).
        // See https://github.com/EricLBuehler/mistral.rs/issues/286
        if let Some(host_flags) = CUDA_NVCC_FLAGS {
            nvcc = nvcc.arg("--compiler-options").arg(host_flags);
        }

        let target = std::env::var("TARGET").unwrap();
        let is_msvc = target.contains("msvc");

        // MSVC targets use a `.lib` file name instead of the `lib*.a` form.
        // See https://github.com/EricLBuehler/mistral.rs/issues/588
        let archive_name = if is_msvc {
            "mistralrscuda.lib"
        } else {
            "libmistralrscuda.a"
        };
        nvcc.build_lib(out_dir.join(archive_name));

        println!("cargo:rustc-link-search={}", out_dir.display());
        println!("cargo:rustc-link-lib=mistralrscuda");
        println!("cargo:rustc-link-lib=dylib=cudart");

        // Pick the C++ runtime directive for the target; MSVC gets none.
        let cxx_runtime_directive = if is_msvc {
            None
        } else if ["apple", "freebsd", "openbsd"]
            .iter()
            .any(|os| target.contains(*os))
        {
            Some("cargo:rustc-link-lib=dylib=c++")
        } else if target.contains("android") {
            Some("cargo:rustc-link-lib=dylib=c++_shared")
        } else {
            Some("cargo:rustc-link-lib=dylib=stdc++")
        };
        if let Some(directive) = cxx_runtime_directive {
            println!("{directive}");
        }
    }
}