forked from LaurentMazare/tch-rs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
82befd7
commit 2ff90de
Showing
6 changed files
with
85 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[package] | ||
name = "pyo3-tch" | ||
version = "0.13.0" | ||
authors = ["Laurent Mazare <[email protected]>"] | ||
edition = "2021" | ||
build = "build.rs" | ||
|
||
description = "Manipulate PyTorch tensors from a Python extension via PyO3/tch." | ||
repository = "https://github.com/LaurentMazare/tch-rs" | ||
keywords = ["pytorch", "deep-learning", "machine-learning"] | ||
categories = ["science"] | ||
license = "MIT/Apache-2.0" | ||
|
||
[dependencies] | ||
tch = { path = "..", features = ["python-extension"], version = "0.13.0" } | ||
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.13.0" } | ||
pyo3 = { version = "0.18.3", features = ["extension-module"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
fn main() { | ||
let os = std::env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); | ||
match os.as_str() { | ||
"linux" | "windows" => { | ||
if let Some(lib_path) = std::env::var_os("DEP_TCH_LIBTORCH_LIB") { | ||
println!("cargo:rustc-link-arg=-Wl,-rpath={}", lib_path.to_string_lossy()); | ||
} | ||
println!("cargo:rustc-link-arg=-Wl,--no-as-needed"); | ||
println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries"); | ||
println!("cargo:rustc-link-arg=-ltorch"); | ||
} | ||
_ => {} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
use pyo3::prelude::*; | ||
use pyo3::{ | ||
exceptions::{PyTypeError, PyValueError}, | ||
AsPyPointer, | ||
}; | ||
pub use tch; | ||
pub use torch_sys; | ||
|
||
pub struct PyTensor(pub tch::Tensor); | ||
|
||
impl std::ops::Deref for PyTensor { | ||
type Target = tch::Tensor; | ||
|
||
fn deref(&self) -> &Self::Target { | ||
&self.0 | ||
} | ||
} | ||
|
||
pub fn wrap_tch_err(err: tch::TchError) -> PyErr { | ||
PyErr::new::<PyValueError, _>(format!("{err:?}")) | ||
} | ||
|
||
impl<'source> FromPyObject<'source> for PyTensor { | ||
fn extract(ob: &'source PyAny) -> PyResult<Self> { | ||
let ptr = ob.as_ptr() as *mut tch::python::CPyObject; | ||
let tensor = unsafe { tch::Tensor::pyobject_unpack(ptr) }; | ||
tensor | ||
.map_err(wrap_tch_err)? | ||
.ok_or_else(|| { | ||
let type_ = ob.get_type(); | ||
PyErr::new::<PyTypeError, _>(format!("expected a torch.Tensor, got {type_}")) | ||
}) | ||
.map(PyTensor) | ||
} | ||
} | ||
|
||
impl IntoPy<PyObject> for PyTensor { | ||
fn into_py(self, py: Python<'_>) -> PyObject { | ||
// There is no fallible alternative to ToPyObject/IntoPy at the moment so we return | ||
// None on errors. https://github.com/PyO3/pyo3/issues/1813 | ||
self.0.pyobject_wrap().map_or_else( | ||
|_| py.None(), | ||
|ptr| unsafe { PyObject::from_owned_ptr(py, ptr as *mut pyo3::ffi::PyObject) }, | ||
) | ||
} | ||
} |