Skip to content

Commit

Permalink
Add Read+Seek trait and Write trait support (LaurentMazare#441)
Browse files Browse the repository at this point in the history
* Add APIs to read/write rust streams.

* Fix inverted dependency.

* Refactor similar functions.

* Run cargo fmt.

* Make ReadStream private.

* Drop explicitly.

* Fix some clippy lint.

* Run cargo fmt.
  • Loading branch information
tsubakisakura authored Jan 27, 2022
1 parent 2cce902 commit 6fd63d8
Show file tree
Hide file tree
Showing 10 changed files with 522 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/nn/var_store.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
//! Variable stores.
use super::Init;
use crate::tensor::Tensor;
use crate::wrappers::stream::ReadSeekAdapter;
use crate::{Device, Kind, TchError};
use std::collections::HashMap;
use std::io::{Read, Seek};
use std::ops::Div;
use std::sync::{Arc, Mutex, MutexGuard};

Expand Down Expand Up @@ -109,6 +111,16 @@ impl VarStore {
Tensor::save_multi(named_tensors.as_slice(), path)
}

/// Saves the var-store variable values to a stream.
///
/// The weight values of every tensor currently held by the var-store
/// are written to the given stream.
pub fn save_to_stream<W: std::io::Write>(&self, stream: W) -> Result<(), TchError> {
    // Hold the lock while snapshotting so the saved set is consistent.
    let guard = self.variables_.lock().unwrap();
    let entries: Vec<_> = guard.named_variables.iter().collect();
    Tensor::save_multi_to_stream(entries.as_slice(), stream)
}

/// Loads the var-store variable values from a file.
///
/// Weight values for all the tensors currently stored in the
Expand All @@ -133,6 +145,31 @@ impl VarStore {
Ok(())
}

/// Loads the var-store variable values from a stream.
///
/// Weight values for all the tensors currently stored in the
/// var-store get loaded from the given stream. Note that the set of
/// variables stored in the var-store is not changed, only the values
/// for these tensors are modified. Every variable in the store must
/// have a matching entry in the stream, otherwise an error is returned.
pub fn load_from_stream<S: Read + Seek>(&mut self, stream: S) -> Result<(), TchError> {
    let adapter = ReadSeekAdapter::new(stream);
    let loaded: HashMap<_, _> =
        Tensor::load_multi_from_stream_with_device(adapter, self.device)?.into_iter().collect();
    let mut variables = self.variables_.lock().unwrap();
    for (name, var) in variables.named_variables.iter_mut() {
        // A variable with no counterpart in the stream is a hard error.
        let src = loaded.get(name).ok_or_else(|| {
            TchError::TensorNameNotFound(name.to_string(), "source stream".to_string())
        })?;
        // Copy values without tracking gradients for the copy itself.
        crate::no_grad(|| var.f_copy_(src).map_err(|e| e.path_context(name)))?;
    }
    Ok(())
}

/// Loads the var-store variable values from a file if it exists.
///
/// Weight values for the tensors currently stored in the var-store and the given file get
Expand Down
2 changes: 2 additions & 0 deletions src/wrappers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#[macro_use]
mod utils;

pub use utils::{
get_num_interop_threads, get_num_threads, manual_seed, set_num_interop_threads,
set_num_threads, QEngine,
Expand All @@ -11,6 +12,7 @@ pub mod jit;
pub mod kind;
pub(crate) mod optimizer;
pub(crate) mod scalar;
pub(crate) mod stream;
pub(crate) mod tensor;
pub(crate) mod tensor_fallible_generated;
pub(crate) mod tensor_generated;
26 changes: 26 additions & 0 deletions src/wrappers/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use std::io::{Read, Result, Seek, SeekFrom};
use torch_sys::io::ReadStream;

/// Adapter wrapping any `Read + Seek` stream so it can be handed to the
/// native side as a `ReadStream` trait object.
pub struct ReadSeekAdapter<T> {
    stream: T,
}

impl<T: Read + Seek> ReadSeekAdapter<T> {
    /// Wraps a readable, seekable stream.
    pub fn new(stream: T) -> Self {
        Self { stream }
    }
}

impl<T: Read> Read for ReadSeekAdapter<T> {
    // Forward reads straight to the wrapped stream.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.stream.read(buf)
    }
}

impl<T: Seek> Seek for ReadSeekAdapter<T> {
    // Forward seeks straight to the wrapped stream.
    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
        self.stream.seek(pos)
    }
}

impl<T: Read + Seek> ReadStream for ReadSeekAdapter<T> {}
89 changes: 89 additions & 0 deletions src/wrappers/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::stream::ReadSeekAdapter;
use super::utils::{path_to_cstring, ptr_to_string};
use super::{
device::{Cuda, Device},
Expand All @@ -7,7 +8,9 @@ use super::{
use crate::TchError;
use libc::{c_char, c_int, c_void};
use std::borrow::Borrow;
use std::io::{Read, Seek, Write};
use std::path::Path;
use torch_sys::io::ReadStream;
use torch_sys::*;

/// A tensor object.
Expand Down Expand Up @@ -534,6 +537,17 @@ impl Tensor {
Ok(Tensor { c_tensor })
}

/// Loads a tensor from a stream.
///
/// The file format is the same as the one used by the PyTorch C++ API.
pub fn load_from_stream<T: Read + Seek>(stream: T) -> Result<Tensor, TchError> {
let adapter = ReadSeekAdapter::new(stream);
let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
let c_tensor =
unsafe_torch_err!(at_load_from_stream(Box::into_raw(boxed_stream) as *mut c_void,));
Ok(Tensor { c_tensor })
}

/// Saves a tensor to a file.
///
/// The file format is the same as the one used by the PyTorch C++ API.
Expand All @@ -543,6 +557,18 @@ impl Tensor {
Ok(())
}

/// Saves a tensor to a stream.
///
/// The file format is the same as the one used by the PyTorch C++ API.
pub fn save_to_stream<W: Write>(&self, stream: W) -> Result<(), TchError> {
let boxed_stream: Box<Box<dyn Write>> = Box::new(Box::new(stream));
unsafe_torch_err!(at_save_to_stream(
self.c_tensor,
Box::into_raw(boxed_stream) as *mut c_void,
));
Ok(())
}

/// Saves some named tensors to a file
///
/// The file format is the same as the one used by the PyTorch C++ API.
Expand All @@ -567,6 +593,30 @@ impl Tensor {
Ok(())
}

/// Saves some named tensors to a stream
///
/// The file format is the same as the one used by the PyTorch C++ API.
pub fn save_multi_to_stream<S: AsRef<str>, T: AsRef<Tensor>, W: Write>(
    named_tensors: &[(S, T)],
    stream: W,
) -> Result<(), TchError> {
    let boxed_stream: Box<Box<dyn Write>> = Box::new(Box::new(stream));
    let c_tensors = named_tensors.iter().map(|nt| nt.1.as_ref().c_tensor).collect::<Vec<_>>();
    let names = named_tensors
        .iter()
        // '.' is mapped to '|' in stored names — NOTE(review): presumably to
        // match the inverse mapping on the load side; confirm against loader.
        // A char pattern is used instead of ".": clippy::single_char_pattern.
        // CString::new takes any Into<Vec<u8>>, so the String is passed
        // directly (the previous .into_bytes() step was redundant).
        .map(|nt| nt.0.as_ref().replace('.', "|"))
        .map(std::ffi::CString::new)
        .collect::<Result<Vec<_>, _>>()?;
    let name_ptrs = names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
    unsafe_torch_err!(at_save_multi_to_stream(
        c_tensors.as_ptr(),
        name_ptrs.as_ptr(),
        // Truncating cast: fine for any realistic tensor count (< i32::MAX).
        names.len() as i32,
        Box::into_raw(boxed_stream) as *mut c_void,
    ));
    Ok(())
}

/// Loads some named tensors from a file
///
/// The file format is the same as the one used by the PyTorch C++ API.
Expand Down Expand Up @@ -599,6 +649,45 @@ impl Tensor {
Ok(v)
}

/// Loads some named tensors from a stream
///
/// The file format is the same as the one used by the PyTorch C++ API.
pub fn load_multi_from_stream<T: Read + Seek>(
    stream: T,
) -> Result<Vec<(String, Tensor)>, TchError> {
    let boxed_stream: Box<Box<dyn ReadStream>> =
        Box::new(Box::new(ReadSeekAdapter::new(stream)));
    let mut named_tensors: Vec<(String, Tensor)> = vec![];
    // Results are accumulated into `named_tensors` through `add_callback`.
    // The trailing `false, 0` mirrors load_multi_from_stream_with_device,
    // which passes `true` plus a device id instead.
    unsafe_torch_err!(at_load_from_stream_callback(
        Box::into_raw(boxed_stream) as *mut c_void,
        &mut named_tensors as *mut _ as *mut c_void,
        add_callback,
        false,
        0,
    ));
    Ok(named_tensors)
}

/// Loads some named tensors from a stream to a given device
///
/// The file format is the same as the one used by the PyTorch C++ API.
pub fn load_multi_from_stream_with_device<T: Read + Seek>(
    stream: T,
    device: Device,
) -> Result<Vec<(String, Tensor)>, TchError> {
    let boxed_stream: Box<Box<dyn ReadStream>> =
        Box::new(Box::new(ReadSeekAdapter::new(stream)));
    let mut named_tensors: Vec<(String, Tensor)> = vec![];
    // Results are accumulated into `named_tensors` through `add_callback`.
    // `true` plus the device id asks the callee to place tensors on `device`.
    unsafe_torch_err!(at_load_from_stream_callback(
        Box::into_raw(boxed_stream) as *mut c_void,
        &mut named_tensors as *mut _ as *mut c_void,
        add_callback,
        true,
        device.c_int(),
    ));
    Ok(named_tensors)
}

/// Returns a string representation for the tensor.
///
/// The representation will contain all the tensor element hence may be huge for
Expand Down
52 changes: 52 additions & 0 deletions tests/serialization_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@ fn save_and_load() {
assert_eq!(Vec::<f64>::from(&t2), vec)
}

#[test]
fn save_to_stream_and_load() {
    // Round-trip: write via the stream API, read back via the path API.
    let filename = std::env::temp_dir().join(format!("tch-write-stream-{}", std::process::id()));
    let values = [3.0, 1.0, 4.0, 1.0, 5.0].to_vec();
    let tensor = Tensor::of_slice(&values);
    let file = std::fs::File::create(&filename).unwrap();
    tensor.save_to_stream(file).unwrap();
    let reloaded = Tensor::load(&filename).unwrap();
    assert_eq!(Vec::<f64>::from(&reloaded), values)
}

#[test]
fn save_and_load_from_stream() {
    // Round-trip: write via the path API, read back via the stream API.
    let filename = std::env::temp_dir().join(format!("tch-read-stream-{}", std::process::id()));
    let values = [3.0, 1.0, 4.0, 1.0, 5.0].to_vec();
    let tensor = Tensor::of_slice(&values);
    tensor.save(&filename).unwrap();
    let file = std::fs::File::open(&filename).unwrap();
    let reloaded = Tensor::load_from_stream(std::io::BufReader::new(file)).unwrap();
    assert_eq!(Vec::<f64>::from(&reloaded), values)
}

#[test]
fn save_and_load_multi() {
let filename = std::env::temp_dir().join(format!("tch2-{}", std::process::id()));
Expand All @@ -23,6 +44,37 @@ fn save_and_load_multi() {
assert_eq!(i64::from(&named_tensors[1].1.sum(tch::Kind::Float)), 57);
}

#[test]
fn save_to_stream_and_load_multi() {
    // Multi-tensor round-trip: stream writer, path reader.
    let filename = std::env::temp_dir().join(format!("tch2-write-stream-{}", std::process::id()));
    let pi = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::of_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    let file = std::fs::File::create(&filename).unwrap();
    Tensor::save_multi_to_stream(&[(&"pi", &pi), (&"e", &e)], file).unwrap();
    let named_tensors = Tensor::load_multi(&filename).unwrap();
    assert_eq!(named_tensors.len(), 2);
    assert_eq!(named_tensors[0].0, "pi");
    assert_eq!(named_tensors[1].0, "e");
    assert_eq!(i64::from(&named_tensors[1].1.sum(tch::Kind::Float)), 57);
}

#[test]
fn save_and_load_multi_from_stream() {
    // Multi-tensor round-trip: path writer, stream reader.
    let filename = std::env::temp_dir().join(format!("tch2-read-stream-{}", std::process::id()));
    let pi = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]);
    let e = Tensor::of_slice(&[2, 7, 1, 8, 2, 8, 1, 8, 2, 8, 4, 6]);
    Tensor::save_multi(&[(&"pi", &pi), (&"e", &e)], &filename).unwrap();
    let file = std::fs::File::open(&filename).unwrap();
    let named_tensors = Tensor::load_multi_from_stream(std::io::BufReader::new(file)).unwrap();
    assert_eq!(named_tensors.len(), 2);
    assert_eq!(named_tensors[0].0, "pi");
    assert_eq!(named_tensors[1].0, "e");
    assert_eq!(i64::from(&named_tensors[1].1.sum(tch::Kind::Float)), 57);
}

#[test]
fn save_and_load_npz() {
let filename = std::env::temp_dir().join(format!("tch3-{}.npz", std::process::id()));
Expand Down
62 changes: 62 additions & 0 deletions tests/var_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,68 @@ fn save_and_load_var_store() {
fs::remove_file(filename).unwrap();
}

#[test]
fn save_to_stream_and_load_var_store() {
    // Bug fix: this test previously shared the exact temp filename
    // "tch-vs-load-stream-complete-{pid}" with save_and_load_from_stream_var_store.
    // Rust runs tests in the same binary concurrently, so the two tests raced on
    // creating/removing the same file and could fail spuriously. Use a name
    // unique to this test instead.
    let filename =
        std::env::temp_dir().join(format!("tch-vs-save-stream-complete-{}", std::process::id()));
    let add = |vs: &tch::nn::Path| {
        let v = vs.sub("a").sub("b").ones("t2", &[3]);
        let u = vs.zeros("t1", &[4]);
        // NOTE(review): the duplicated registration below is kept as-is from the
        // original test body.
        let _w = vs.sub("a").sub("b").sub("ccc").ones("t123", &[3]);
        let _w = vs.sub("a").sub("b").sub("ccc").ones("t123", &[3]);
        (u, v)
    };
    let vs1 = VarStore::new(Device::Cpu);
    let mut vs2 = VarStore::new(Device::Cpu);
    let (mut u1, mut v1) = add(&vs1.root());
    let (u2, v2) = add(&vs2.root());
    tch::no_grad(|| {
        u1 += 42.0;
        v1 *= 2.0;
    });
    assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0);
    assert_eq!(f64::from(&v1.mean(Kind::Float)), 2.0);
    assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0);
    assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0);
    // Save through the stream API, reload through the path API.
    vs1.save_to_stream(std::fs::File::create(&filename).unwrap()).unwrap();
    vs2.load(&filename).unwrap();
    assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0);
    assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0);
    assert_eq!(f64::from(&v2.mean(Kind::Float)), 2.0);
    fs::remove_file(filename).unwrap();
}

#[test]
fn save_and_load_from_stream_var_store() {
    // Save through the path API, reload through the stream API, and check
    // that values land in the second var-store.
    let filename =
        std::env::temp_dir().join(format!("tch-vs-load-stream-complete-{}", std::process::id()));
    let build_vars = |vs: &tch::nn::Path| {
        let v = vs.sub("a").sub("b").ones("t2", &[3]);
        let u = vs.zeros("t1", &[4]);
        // NOTE(review): duplicated registration kept as-is from the original.
        let _w = vs.sub("a").sub("b").sub("ccc").ones("t123", &[3]);
        let _w = vs.sub("a").sub("b").sub("ccc").ones("t123", &[3]);
        (u, v)
    };
    let source_store = VarStore::new(Device::Cpu);
    let mut target_store = VarStore::new(Device::Cpu);
    let (mut u1, mut v1) = build_vars(&source_store.root());
    let (u2, v2) = build_vars(&target_store.root());
    tch::no_grad(|| {
        u1 += 42.0;
        v1 *= 2.0;
    });
    assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0);
    assert_eq!(f64::from(&v1.mean(Kind::Float)), 2.0);
    assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0);
    assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0);
    source_store.save(&filename).unwrap();
    target_store.load_from_stream(std::fs::File::open(&filename).unwrap()).unwrap();
    assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0);
    assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0);
    assert_eq!(f64::from(&v2.mean(Kind::Float)), 2.0);
    fs::remove_file(filename).unwrap();
}

#[test]
fn save_and_load_partial_var_store() {
let filename =
Expand Down
Loading

0 comments on commit 6fd63d8

Please sign in to comment.