diff --git a/src/rust/engine/fs/store/src/lib.rs b/src/rust/engine/fs/store/src/lib.rs index 68914f44068..b18bec879e3 100644 --- a/src/rust/engine/fs/store/src/lib.rs +++ b/src/rust/engine/fs/store/src/lib.rs @@ -42,7 +42,7 @@ use std::fmt::{self, Debug, Display}; use std::fs::hard_link; use std::fs::OpenOptions; use std::future::Future; -use std::io::{self, Read, Write}; +use std::io::{self, Read, Seek, Write}; use std::os::unix::fs::{symlink, OpenOptionsExt, PermissionsExt}; use std::path::{Path, PathBuf}; use std::sync::{Arc, Weak}; @@ -237,30 +237,60 @@ impl RemoteStore { } /// Download the digest to the local byte store from this remote store. The function `f_remote` - /// can be used to validate the bytes. - async fn download_digest_to_local< - FRemote: Fn(Bytes) -> Result<(), String> + Send + Sync + 'static, - >( + /// can be used to validate the bytes (NB. if provided, the whole value will be buffered into + /// memory to provide the `Bytes` argument, and thus `f_remote` should only be used for small digests). + async fn download_digest_to_local( &self, local_store: local::ByteStore, digest: Digest, entry_type: EntryType, - f_remote: FRemote, + f_remote: Option<&(dyn Fn(Bytes) -> Result<(), String> + Send + Sync + 'static)>, ) -> Result<(), StoreError> { let remote_store = self.store.clone(); + let create_missing = || { + StoreError::MissingDigest( + "Was not present in either the local or remote store".to_owned(), + digest, + ) + }; self .maybe_download(digest, async move { - let bytes = remote_store.load_bytes(digest).await?.ok_or_else(|| { - StoreError::MissingDigest( - "Was not present in either the local or remote store".to_owned(), - digest, - ) - })?; + let stored_digest = if digest.size_bytes <= IMMUTABLE_FILE_SIZE_LIMIT || f_remote.is_some() + { + // (if there's a function to call, always just buffer fully into memory) + let bytes = remote_store + .load_bytes(digest) + .await? + .ok_or_else(create_missing)?; + if let Some(f_remote) = f_remote { + f_remote(bytes.clone())?; + } + local_store + .store_bytes(entry_type, None, bytes, true) + .await? + } else { + assert!(f_remote.is_none()); + // TODO(#18048): choose a file that can be plopped into the local store directly, when + // large files are stored there + let file = tokio::task::spawn_blocking(tempfile::tempfile) + .await + .map_err(|e| e.to_string())??; + let file = tokio::fs::File::from_std(file); - f_remote(bytes.clone())?; - let stored_digest = local_store - .store_bytes(entry_type, None, bytes, true) - .await?; + let file = remote_store + .load_file(digest, file) + .await? + .ok_or_else(create_missing)?; + + let file = file.into_std().await; + local_store + .store(entry_type, true, true, move || { + let mut file = file.try_clone()?; + file.rewind()?; + Ok(file) + }) + .await? + }; if digest == stored_digest { Ok(()) } else { @@ -510,12 +540,7 @@ impl Store { ) -> Result { // No transformation or verification is needed for files. self - .load_bytes_with( - EntryType::File, - digest, - move |v: &[u8]| Ok(f(v)), - |_: Bytes| Ok(()), - ) + .load_bytes_with(EntryType::File, digest, move |v: &[u8]| Ok(f(v)), None) .await } @@ -687,13 +712,13 @@ impl Store { }, // Eagerly verify that CAS-returned Directories are canonical, so that we don't write them // into our local store. - move |bytes: Bytes| { + Some(&move |bytes| { let directory = remexec::Directory::decode(bytes).map_err(|e| { format!("CAS returned Directory proto for {digest:?} which was not valid: {e:?}") })?; protos::verify_directory_canonical(digest, &directory)?; Ok(()) - }, + }), ) .await } @@ -721,13 +746,12 @@ impl Store { async fn load_bytes_with< T: Send + 'static, FLocal: Fn(&[u8]) -> Result + Clone + Send + Sync + 'static, - FRemote: Fn(Bytes) -> Result<(), String> + Send + Sync + 'static, >( &self, entry_type: EntryType, digest: Digest, f_local: FLocal, - f_remote: FRemote, + f_remote: Option<&(dyn Fn(Bytes) -> Result<(), String> + Send + Sync + 'static)>, ) -> Result { if let Some(bytes_res) = self .local @@ -997,7 +1021,7 @@ impl Store { .into_iter() .map(|file_digest| async move { if let Err(e) = remote - .download_digest_to_local(self.local.clone(), file_digest, EntryType::File, |_| Ok(())) + .download_digest_to_local(self.local.clone(), file_digest, EntryType::File, None) .await { log::debug!("Missing file digest from remote store: {:?}", file_digest); diff --git a/src/rust/engine/fs/store/src/remote.rs b/src/rust/engine/fs/store/src/remote.rs index 867183eb5fa..cf5a7fc1a78 100644 --- a/src/rust/engine/fs/store/src/remote.rs +++ b/src/rust/engine/fs/store/src/remote.rs @@ -9,11 +9,14 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use async_oncecell::OnceCell; -use bytes::{Bytes, BytesMut}; -use futures::Future; +use async_trait::async_trait; +use bytes::Bytes; use futures::StreamExt; +use futures::{Future, FutureExt}; use grpc_util::retry::{retry_call, status_is_retryable}; -use grpc_util::{headers_to_http_header_map, layered_service, status_to_str, LayeredService}; +use grpc_util::{ + headers_to_http_header_map, layered_service, status_ref_to_str, status_to_str, LayeredService, +}; use hashing::Digest; use log::Level; use protos::gen::build::bazel::remote::execution::v2 as remexec; @@ -23,6 +26,8 @@ use remexec::{ content_addressable_storage_client::ContentAddressableStorageClient, BatchUpdateBlobsRequest, ServerCapabilities, }; +use tokio::io::{AsyncSeekExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::Mutex; use tonic::{Code, Request, Status}; use workunit_store::{in_workunit, ObservationMetric}; @@ -57,10 +62,36 @@ enum ByteStoreError { Other(String), } +impl ByteStoreError { + fn retryable(&self) -> bool { + match self { + ByteStoreError::Grpc(status) => status_is_retryable(status), + ByteStoreError::Other(_) => false, + } + } +} + +impl From for ByteStoreError { + fn from(status: Status) -> ByteStoreError { + ByteStoreError::Grpc(status) + } +} + +impl From for ByteStoreError { + fn from(string: String) -> ByteStoreError { + ByteStoreError::Other(string) + } +} +impl From for ByteStoreError { + fn from(err: std::io::Error) -> ByteStoreError { + ByteStoreError::Other(err.to_string()) + } +} + impl fmt::Display for ByteStoreError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ByteStoreError::Grpc(status) => fmt::Display::fmt(status, f), + ByteStoreError::Grpc(status) => fmt::Display::fmt(&status_ref_to_str(status), f), ByteStoreError::Other(msg) => fmt::Display::fmt(msg, f), } } @@ -68,6 +99,29 @@ impl fmt::Display for ByteStoreError { impl std::error::Error for ByteStoreError {} +/// Places that write the result of a remote `load` +#[async_trait] +trait LoadDestination: AsyncWrite + Send + Sync + Unpin + 'static { + /// Clear out the writer and start again, if there's been previous contents written + async fn reset(&mut self) -> std::io::Result<()>; +} + +#[async_trait] +impl LoadDestination for tokio::fs::File { + async fn reset(&mut self) -> std::io::Result<()> { + self.rewind().await?; + self.set_len(0).await + } +} + +#[async_trait] +impl LoadDestination for Vec { + async fn reset(&mut self) -> std::io::Result<()> { + self.clear(); + Ok(()) + } +} + impl ByteStore { // TODO: Consider extracting these options to a struct with `impl Default`, similar to // `super::LocalOptions`. @@ -165,16 +219,10 @@ impl ByteStore { retry_call( mmap, |mmap| self.store_bytes_source(digest, move |range| Bytes::copy_from_slice(&mmap[range])), - |err| match err { - ByteStoreError::Grpc(status) => status_is_retryable(status), - _ => false, - }, + ByteStoreError::retryable, ) .await - .map_err(|err| match err { - ByteStoreError::Grpc(status) => status_to_str(status).into(), - ByteStoreError::Other(msg) => msg.into(), - }) + .map_err(|e| e.to_string().into()) } pub async fn store_bytes(&self, bytes: Bytes) -> Result<(), String> { @@ -182,16 +230,10 @@ impl ByteStore { retry_call( bytes, |bytes| self.store_bytes_source(digest, move |range| bytes.slice(range)), - |err| match err { - ByteStoreError::Grpc(status) => status_is_retryable(status), - _ => false, - }, + ByteStoreError::retryable, ) .await - .map_err(|err| match err { - ByteStoreError::Grpc(status) => status_to_str(status), - ByteStoreError::Other(msg) => msg, - }) + .map_err(|e| e.to_string()) } async fn store_bytes_source( @@ -325,7 +367,11 @@ impl ByteStore { .await } - pub async fn load_bytes(&self, digest: Digest) -> Result, String> { + async fn load_monomorphic( + &self, + digest: Digest, + destination: &mut dyn LoadDestination, + ) -> Result { let start = Instant::now(); let store = self.clone(); let instance_name = store.instance_name.clone().unwrap_or_default(); @@ -346,25 +392,16 @@ impl ByteStore { }; let client = self.byte_stream_client.as_ref().clone(); + let destination = Arc::new(Mutex::new(destination)); + let result_future = retry_call( - (client, request), - move |(mut client, request)| async move { - let mut start_opt = Some(Instant::now()); - let stream_result = client.read(request).await; - - let mut stream = match stream_result { - Ok(response) => response.into_inner(), - Err(status) => { - return match status.code() { - Code::NotFound => Ok(None), - _ => Err(status), - } - } - }; + (client, request, destination), + move |(mut client, request, destination)| { + async move { + let mut start_opt = Some(Instant::now()); + let response = client.read(request).await?; - let read_result_closure = async { - let mut buf = BytesMut::with_capacity(digest.size_bytes); - while let Some(response) = stream.next().await { + let mut stream = response.into_inner().inspect(|_| { // Record the observed time to receive the first response for this read. if let Some(start) = start_opt.take() { if let Some(workunit_store_handle) = workunit_store::get_workunit_store_handle() { @@ -377,31 +414,31 @@ impl ByteStore { } } } + }); - buf.extend_from_slice(&(response?).data); + let mut writer = destination.lock().await; + writer.reset().await?; + while let Some(response) = stream.next().await { + writer.write_all(&(response?).data).await?; } - Ok(buf.freeze()) - }; - - let read_result: Result = read_result_closure.await; - match read_result { - Ok(bytes) => Ok(Some(bytes)), - Err(status) => match status.code() { - Code::NotFound => Ok(None), - _ => Err(status), - }, + Ok(()) } + .map(|read_result| match read_result { + Ok(()) => Ok(true), + Err(ByteStoreError::Grpc(status)) if status.code() == Code::NotFound => Ok(false), + Err(err) => Err(err), + }) }, - status_is_retryable, + ByteStoreError::retryable, ); in_workunit!( - "load_bytes", + "load", Level::Trace, desc = Some(workunit_desc), |workunit| async move { - let result = result_future.await.map_err(status_to_str); + let result = result_future.await.map_err(|e| e.to_string()); workunit.record_observation( ObservationMetric::RemoteStoreReadBlobTimeMicros, start.elapsed().as_micros() as u64, @@ -418,6 +455,40 @@ impl ByteStore { .await } + async fn load( + &self, + digest: Digest, + mut destination: W, + ) -> Result, String> { + // TODO(#18231): compute the digest as we write, to avoid needing a second pass to validate + if self.load_monomorphic(digest, &mut destination).await? { + Ok(Some(destination)) + } else { + Ok(None) + } + } + + /// Load the data for `digest` (if it exists in the remote store) into memory. + pub async fn load_bytes(&self, digest: Digest) -> Result, String> { + let result = self + .load(digest, Vec::with_capacity(digest.size_bytes)) + .await?; + Ok(result.map(Bytes::from)) + } + + /// Write the data for `digest` (if it exists in the remote store) into `file`. + pub async fn load_file( + &self, + digest: Digest, + file: tokio::fs::File, + ) -> Result, String> { + let mut result = self.load(digest, file).await; + if let Ok(Some(ref mut file)) = result { + file.rewind().await.map_err(|e| e.to_string())?; + } + result + } + /// /// Given a collection of Digests (digests), /// returns the set of digests from that collection not present in the CAS. diff --git a/src/rust/engine/fs/store/src/remote_tests.rs b/src/rust/engine/fs/store/src/remote_tests.rs index e4c21632f16..726f044d428 100644 --- a/src/rust/engine/fs/store/src/remote_tests.rs +++ b/src/rust/engine/fs/store/src/remote_tests.rs @@ -8,6 +8,7 @@ use grpc_util::tls; use hashing::Digest; use mock::StubCAS; use testutil::data::{TestData, TestDirectory}; +use tokio::io::AsyncReadExt; use workunit_store::WorkunitStore; use crate::remote::ByteStore; @@ -27,6 +28,36 @@ async fn loads_file() { ); } +#[tokio::test] +async fn loads_huge_file_via_temp_file() { + // 5MB of data + let testdata = TestData::new(&"12345".repeat(MEGABYTES)); + + let _ = WorkunitStore::setup_for_tests(); + let cas = StubCAS::builder() + .chunk_size_bytes(MEGABYTES) + .file(&testdata) + .build(); + + let file = tokio::task::spawn_blocking(tempfile::tempfile) + .await + .unwrap() + .unwrap(); + let file = tokio::fs::File::from_std(file); + + let mut file = new_byte_store(&cas) + .load_file(testdata.digest(), file) + .await + .unwrap() + .unwrap(); + + let mut buf = String::new(); + file.read_to_string(&mut buf).await.unwrap(); + assert_eq!(buf.len(), testdata.len()); + // (assert_eq! means failures unhelpfully print a 5MB string) + assert!(buf == testdata.string()); +} + #[tokio::test] async fn missing_file() { let _ = WorkunitStore::setup_for_tests(); diff --git a/src/rust/engine/fs/store/src/tests.rs b/src/rust/engine/fs/store/src/tests.rs index e18dd9dd3a6..844cb3f8489 100644 --- a/src/rust/engine/fs/store/src/tests.rs +++ b/src/rust/engine/fs/store/src/tests.rs @@ -179,6 +179,36 @@ async fn load_file_falls_back_and_backfills() { ); } +#[tokio::test] +async fn load_file_falls_back_and_backfills_for_huge_file() { + let dir = TempDir::new().unwrap(); + + // 5MB of data + let testdata = TestData::new(&"12345".repeat(MEGABYTES)); + + let _ = WorkunitStore::setup_for_tests(); + let cas = StubCAS::builder() + .chunk_size_bytes(MEGABYTES) + .file(&testdata) + .build(); + + assert!( + load_file_bytes(&new_store(dir.path(), &cas.address()), testdata.digest()).await + == Ok(testdata.bytes()), + "Read from CAS" + ); + assert_eq!(1, cas.read_request_count()); + assert!( + crate::local_tests::load_file_bytes( + &crate::local_tests::new_store(dir.path()), + testdata.digest(), + ) + .await + == Ok(Some(testdata.bytes())), + "Read from local cache" + ); +} + #[tokio::test] async fn load_directory_falls_back_and_backfills() { let dir = TempDir::new().unwrap(); diff --git a/src/rust/engine/grpc_util/src/lib.rs b/src/rust/engine/grpc_util/src/lib.rs index 20939bd1571..cf56fb977ec 100644 --- a/src/rust/engine/grpc_util/src/lib.rs +++ b/src/rust/engine/grpc_util/src/lib.rs @@ -148,10 +148,14 @@ pub fn headers_to_http_header_map(headers: &BTreeMap) -> Result< Ok(HeaderMap::from_iter(http_headers)) } -pub fn status_to_str(status: tonic::Status) -> String { +pub fn status_ref_to_str(status: &tonic::Status) -> String { format!("{:?}: {:?}", status.code(), status.message()) } +pub fn status_to_str(status: tonic::Status) -> String { + status_ref_to_str(&status) +} + #[derive(Clone)] pub struct CountErrorsService { service: S, diff --git a/src/rust/engine/grpc_util/src/retry.rs b/src/rust/engine/grpc_util/src/retry.rs index 1e28bfaedc4..e17cc661225 100644 --- a/src/rust/engine/grpc_util/src/retry.rs +++ b/src/rust/engine/grpc_util/src/retry.rs @@ -21,10 +21,10 @@ pub fn status_is_retryable(status: &Status) -> bool { /// Retry a gRPC client operation using exponential back-off to delay between attempts. #[inline] -pub async fn retry_call(client: C, f: F, is_retryable: G) -> Result +pub async fn retry_call(client: C, mut f: F, is_retryable: G) -> Result where C: Clone, - F: Fn(C) -> Fut, + F: FnMut(C) -> Fut, G: Fn(&E) -> bool, Fut: Future>, {