Skip to content

Commit

Permalink
Add a dedicated pyo3-tch package.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed May 29, 2023
1 parent 82befd7 commit 2ff90de
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 37 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ memmap2 = { version = "0.6.1", optional = true }
anyhow = "1"

[workspace]
members = ["torch-sys", "examples/python-extension"]
members = [
"torch-sys",
"pyo3-tch",
"examples/python-extension",
]

[features]
download-libtorch = ["torch-sys/download-libtorch"]
Expand Down
1 change: 1 addition & 0 deletions examples/python-extension/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.18.3", features = ["extension-module"] }
pyo3-tch = { path = "../../pyo3-tch", version = "0.13.0" }
tch = { path = "../..", features = ["python-extension"], version = "0.13.0" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.13.0" }
38 changes: 2 additions & 36 deletions examples/python-extension/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,9 @@
use pyo3::prelude::*;
use pyo3::{
exceptions::{PyTypeError, PyValueError},
AsPyPointer,
};

struct PyTensor(tch::Tensor);

fn wrap_tch_err(err: tch::TchError) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}

impl<'source> FromPyObject<'source> for PyTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let ptr = ob.as_ptr() as *mut tch::python::CPyObject;
let tensor = unsafe { tch::Tensor::pyobject_unpack(ptr) };
tensor
.map_err(wrap_tch_err)?
.ok_or_else(|| {
let type_ = ob.get_type();
PyErr::new::<PyTypeError, _>(format!("expected a torch.Tensor, got {type_}"))
})
.map(PyTensor)
}
}

impl IntoPy<PyObject> for PyTensor {
fn into_py(self, py: Python<'_>) -> PyObject {
// There is no fallible alternative to ToPyObject/IntoPy at the moment so we return
// None on errors. https://github.com/PyO3/pyo3/issues/1813
self.0.pyobject_wrap().map_or_else(
|_| py.None(),
|ptr| unsafe { PyObject::from_owned_ptr(py, ptr as *mut pyo3::ffi::PyObject) },
)
}
}
use pyo3_tch::{wrap_tch_err, PyTensor};

#[pyfunction]
fn add_one(tensor: PyTensor) -> PyResult<PyTensor> {
let tensor = tensor.0.f_add_scalar(1.0).map_err(wrap_tch_err)?;
let tensor = tensor.f_add_scalar(1.0).map_err(wrap_tch_err)?;
Ok(PyTensor(tensor))
}

Expand Down
17 changes: 17 additions & 0 deletions pyo3-tch/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "pyo3-tch"
version = "0.13.0"
authors = ["Laurent Mazare <[email protected]>"]
edition = "2021"
build = "build.rs"

description = "Manipulate PyTorch tensors from a Python extension via PyO3/tch."
repository = "https://github.com/LaurentMazare/tch-rs"
keywords = ["pytorch", "deep-learning", "machine-learning"]
categories = ["science"]
license = "MIT/Apache-2.0"

[dependencies]
tch = { path = "..", features = ["python-extension"], version = "0.13.0" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.13.0" }
pyo3 = { version = "0.18.3", features = ["extension-module"] }
14 changes: 14 additions & 0 deletions pyo3-tch/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
fn main() {
let os = std::env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS");
match os.as_str() {
"linux" | "windows" => {
if let Some(lib_path) = std::env::var_os("DEP_TCH_LIBTORCH_LIB") {
println!("cargo:rustc-link-arg=-Wl,-rpath={}", lib_path.to_string_lossy());
}
println!("cargo:rustc-link-arg=-Wl,--no-as-needed");
println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries");
println!("cargo:rustc-link-arg=-ltorch");
}
_ => {}
}
}
46 changes: 46 additions & 0 deletions pyo3-tch/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use pyo3::prelude::*;
use pyo3::{
exceptions::{PyTypeError, PyValueError},
AsPyPointer,
};
pub use tch;
pub use torch_sys;

pub struct PyTensor(pub tch::Tensor);

impl std::ops::Deref for PyTensor {
type Target = tch::Tensor;

fn deref(&self) -> &Self::Target {
&self.0
}
}

pub fn wrap_tch_err(err: tch::TchError) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}

impl<'source> FromPyObject<'source> for PyTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let ptr = ob.as_ptr() as *mut tch::python::CPyObject;
let tensor = unsafe { tch::Tensor::pyobject_unpack(ptr) };
tensor
.map_err(wrap_tch_err)?
.ok_or_else(|| {
let type_ = ob.get_type();
PyErr::new::<PyTypeError, _>(format!("expected a torch.Tensor, got {type_}"))
})
.map(PyTensor)
}
}

impl IntoPy<PyObject> for PyTensor {
fn into_py(self, py: Python<'_>) -> PyObject {
// There is no fallible alternative to ToPyObject/IntoPy at the moment so we return
// None on errors. https://github.com/PyO3/pyo3/issues/1813
self.0.pyobject_wrap().map_or_else(
|_| py.None(),
|ptr| unsafe { PyObject::from_owned_ptr(py, ptr as *mut pyo3::ffi::PyObject) },
)
}
}

0 comments on commit 2ff90de

Please sign in to comment.