Skip to content

Commit

Permalink
Add --remote-store-headers and fix --remote-execution-headers to …
Browse files Browse the repository at this point in the history
…not impact remote caching (pantsbuild#11501)

Previously, we only had `--remote-execution-headers`. We weren't applying those to the store setup, and we were incorrectly applying them to the remote cache setup.

We considered instead consolidating to a single `--remote-headers` option that applies to both contexts. However, prior art from Bazel shows that it's useful to have service-specific options, and this way we also avoid a deprecation warning. We could add both `--remote-headers` and `--remote-store-headers`, but for simplicity that's not done here. We can add `--remote-headers` in the future, if necessary.

This PR also refactors the handling of `--remote-oauth-bearer-token-path` so that the token is converted into the relevant header in Python, before crossing the FFI boundary. This simplifies our Rust code, which now simply receives a dictionary of headers. All header injection now happens in Python, which will facilitate adding a plugin hook to dynamically set these headers.
  • Loading branch information
Eric-Arellano authored Jan 28, 2021
1 parent c6ad58a commit 2a5346e
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 99 deletions.
6 changes: 2 additions & 4 deletions src/python/pants/engine/internals/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def new_scheduler(
execution_process_cache_namespace=execution_options.process_execution_cache_namespace,
instance_name=execution_options.remote_instance_name,
root_ca_certs_path=execution_options.remote_ca_certs_path,
oauth_bearer_token_path=execution_options.remote_oauth_bearer_token_path,
store_headers=tuple(execution_options.remote_store_headers.items()),
store_thread_count=execution_options.remote_store_thread_count,
store_chunk_bytes=execution_options.remote_store_chunk_bytes,
store_chunk_upload_timeout=execution_options.remote_store_chunk_upload_timeout_seconds,
Expand All @@ -270,9 +270,7 @@ def new_scheduler(
tuple(pair.split("=", 1))
for pair in execution_options.remote_execution_extra_platform_properties
),
execution_headers=tuple(
(k, v) for (k, v) in execution_options.remote_execution_headers.items()
),
execution_headers=tuple(execution_options.remote_execution_headers.items()),
execution_overall_deadline_secs=execution_options.remote_execution_overall_deadline_secs,
)

Expand Down
88 changes: 63 additions & 25 deletions src/python/pants/option/global_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import tempfile
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
from pathlib import Path
from typing import Any, Dict, List, Optional, cast

from pants.base.build_environment import (
get_buildroot,
Expand Down Expand Up @@ -72,7 +73,6 @@ class ExecutionOptions:

remote_instance_name: Optional[str]
remote_ca_certs_path: Optional[str]
remote_oauth_bearer_token_path: Optional[str]

process_execution_local_parallelism: int
process_execution_remote_parallelism: int
Expand All @@ -82,6 +82,7 @@ class ExecutionOptions:
process_execution_local_enable_nailgun: bool

remote_store_server: List[str]
remote_store_headers: Dict[str, str]
remote_store_thread_count: int
remote_store_chunk_bytes: Any
remote_store_chunk_upload_timeout_seconds: int
Expand All @@ -100,6 +101,22 @@ class ExecutionOptions:

@classmethod
def from_bootstrap_options(cls, bootstrap_options: OptionValueContainer) -> ExecutionOptions:
# Possibly insert some headers.
remote_execution_headers = cast(Dict[str, str], bootstrap_options.remote_execution_headers)
remote_store_headers = cast(Dict[str, str], bootstrap_options.remote_store_headers)
if bootstrap_options.remote_oauth_bearer_token_path:
oauth_token = (
Path(bootstrap_options.remote_oauth_bearer_token_path).resolve().read_text().strip()
)
if set(oauth_token).intersection({"\n", "\r"}):
raise OptionsError(
f"OAuth bearer token path {bootstrap_options.remote_oauth_bearer_token_path} "
"must not contain multiple lines."
)
token_header = {"authorization": f"Bearer {oauth_token}"}
remote_execution_headers.update(token_header)
remote_store_headers.update(token_header)

return cls(
# Remote execution strategy.
remote_execution=bootstrap_options.remote_execution,
Expand All @@ -108,7 +125,6 @@ def from_bootstrap_options(cls, bootstrap_options: OptionValueContainer) -> Exec
# General remote setup.
remote_instance_name=bootstrap_options.remote_instance_name,
remote_ca_certs_path=bootstrap_options.remote_ca_certs_path,
remote_oauth_bearer_token_path=bootstrap_options.remote_oauth_bearer_token_path,
# Process execution setup.
process_execution_local_parallelism=bootstrap_options.process_execution_local_parallelism,
process_execution_remote_parallelism=bootstrap_options.process_execution_remote_parallelism,
Expand All @@ -118,6 +134,7 @@ def from_bootstrap_options(cls, bootstrap_options: OptionValueContainer) -> Exec
process_execution_local_enable_nailgun=bootstrap_options.process_execution_local_enable_nailgun,
# Remote store setup.
remote_store_server=bootstrap_options.remote_store_server,
remote_store_headers=remote_store_headers,
remote_store_thread_count=bootstrap_options.remote_store_thread_count,
remote_store_chunk_bytes=bootstrap_options.remote_store_chunk_bytes,
remote_store_chunk_upload_timeout_seconds=bootstrap_options.remote_store_chunk_upload_timeout_seconds,
Expand All @@ -131,7 +148,7 @@ def from_bootstrap_options(cls, bootstrap_options: OptionValueContainer) -> Exec
# Remote execution setup.
remote_execution_server=bootstrap_options.remote_execution_server,
remote_execution_extra_platform_properties=bootstrap_options.remote_execution_extra_platform_properties,
remote_execution_headers=bootstrap_options.remote_execution_headers,
remote_execution_headers=remote_execution_headers,
remote_execution_overall_deadline_secs=bootstrap_options.remote_execution_overall_deadline_secs,
)

Expand All @@ -149,7 +166,6 @@ def from_bootstrap_options(cls, bootstrap_options: OptionValueContainer) -> Exec
# General remote setup.
remote_instance_name=None,
remote_ca_certs_path=None,
remote_oauth_bearer_token_path=None,
# Process execution setup.
process_execution_local_parallelism=_CPU_COUNT,
process_execution_remote_parallelism=128,
Expand All @@ -159,6 +175,7 @@ def from_bootstrap_options(cls, bootstrap_options: OptionValueContainer) -> Exec
process_execution_local_enable_nailgun=False,
# Remote store setup.
remote_store_server=[],
remote_store_headers={},
remote_store_thread_count=1,
remote_store_chunk_bytes=1024 * 1024,
remote_store_chunk_upload_timeout_seconds=60,
Expand Down Expand Up @@ -748,18 +765,32 @@ def register_bootstrap_options(cls, register):
register(
"--remote-oauth-bearer-token-path",
advanced=True,
help="Path to a file containing an oauth token to use for grpc connections to "
"--remote-execution-server and --remote-store-server. If not specified, no "
"authorization will be performed.",
help=(
"Path to a file containing an oauth token to use for gGRPC connections to "
"--remote-execution-server and --remote-store-server.\n\nIf specified, Pants will "
"add a header in the format `authorization: Bearer <token>`. You can also manually "
"add this header via `--remote-execution-headers` and `--remote-store-headers`. "
"Otherwise, no authorization will be performed."
),
)

register(
"--remote-store-server",
advanced=True,
type=list,
default=[],
default=DEFAULT_EXECUTION_OPTIONS.remote_store_server,
help="host:port of grpc server to use as remote execution file store.",
)
register(
"--remote-store-headers",
advanced=True,
type=dict,
default=DEFAULT_EXECUTION_OPTIONS.remote_store_headers,
help=(
"Headers to set on remote store requests. Format: header=value. Pants "
"may add additional headers.\n\nSee `--remote-execution-headers` as well."
),
)
# TODO: Infer this from remote-store-connection-limit.
register(
"--remote-store-thread-count",
Expand Down Expand Up @@ -842,15 +873,17 @@ def register_bootstrap_options(cls, register):
"Format: property=value. Multiple values should be specified as multiple "
"occurrences of this flag. Pants itself may add additional platform properties.",
type=list,
default=[],
default=DEFAULT_EXECUTION_OPTIONS.remote_execution_extra_platform_properties,
)
register(
"--remote-execution-headers",
advanced=True,
help="Headers to set on remote execution requests. "
"Format: header=value. Pants itself may add additional headers.",
type=dict,
default={},
default=DEFAULT_EXECUTION_OPTIONS.remote_execution_headers,
help=(
"Headers to set on remote execution requests. Format: header=value. Pants "
"may add additional headers.\n\nSee `--remote-store-headers` as well."
),
)
register(
"--remote-execution-overall-deadline-secs",
Expand Down Expand Up @@ -1034,15 +1067,20 @@ def validate_instance(cls, opts):
"milliseconds."
)

# Ensure that remote headers are ASCII (gRPC requirement).
for k, v in opts.remote_execution_headers.items():
if not k.isascii():
raise OptionsError(
f"All values in `--remote-execution-headers` must be ASCII "
f"(as required by gRPC), but the key in `{k}: {v}` has non-ASCII characters."
)
if not v.isascii():
raise OptionsError(
f"All values in `--remote-execution-headers` must be ASCII "
f"(as required by gRPC), but the value in `{k}: {v}` has non-ASCII characters."
)
# Ensure that remote headers are ASCII.
def validate_headers(opt_name: str) -> None:
    """Raise `OptionsError` if the named header option has a non-ASCII key or value.

    `opt_name` is the attribute name to read from `opts`; in the error message it
    is shown in command-line form (leading `--`, underscores turned into dashes).
    """
    flag = "--" + opt_name.replace("_", "-")
    headers = getattr(opts, opt_name)
    for name, content in headers.items():
        # The key is checked before the value, so a bad key is reported first.
        for part, text in (("key", name), ("value", content)):
            if text.isascii():
                continue
            raise OptionsError(
                f"All values in `{flag}` must be ASCII, but the {part} in "
                f"`{name}: {content}` has non-ASCII characters."
            )

validate_headers("remote_execution_headers")
validate_headers("remote_store_headers")
24 changes: 15 additions & 9 deletions src/rust/engine/fs/fs_util/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use futures::FutureExt;
use grpc_util::prost::MessageExt;
use hashing::{Digest, Fingerprint};
use parking_lot::Mutex;
use prost::alloc::collections::BTreeMap;
use rand::seq::SliceRandom;
use serde_derive::Serialize;
use store::{
Expand Down Expand Up @@ -294,14 +295,19 @@ async fn execute(top_match: &clap::ArgMatches<'_>) -> Result<(), ExitError> {
None
};

let oauth_bearer_token =
if let Some(path) = top_match.value_of("oauth-bearer-token-file") {
Some(std::fs::read_to_string(path).map_err(|err| {
format!("Error reading oauth bearer token from {:?}: {}", path, err)
})?)
} else {
None
};
let mut headers = BTreeMap::new();
if let Some(oauth_path) = top_match.value_of("oauth-bearer-token-file") {
let token = std::fs::read_to_string(oauth_path).map_err(|err| {
format!(
"Error reading oauth bearer token from {:?}: {}",
oauth_path, err
)
})?;
headers.insert(
"authorization".to_owned(),
format!("Bearer {}", token.trim()),
);
}

// Randomize CAS address order to avoid thundering herds from common config.
let mut cas_addresses = cas_address.map(str::to_owned).collect::<Vec<_>>();
Expand All @@ -316,7 +322,7 @@ async fn execute(top_match: &clap::ArgMatches<'_>) -> Result<(), ExitError> {
.value_of("remote-instance-name")
.map(str::to_owned),
root_ca_certs,
oauth_bearer_token,
headers,
value_t!(top_match.value_of("thread-count"), usize).expect("Invalid thread count"),
chunk_size,
// This deadline is really only in place because otherwise DNS failures
Expand Down
4 changes: 2 additions & 2 deletions src/rust/engine/fs/store/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl Store {
cas_addresses: Vec<String>,
instance_name: Option<String>,
root_ca_certs: Option<Vec<u8>>,
oauth_bearer_token: Option<String>,
headers: BTreeMap<String, String>,
thread_count: usize,
chunk_size_bytes: usize,
upload_timeout: Duration,
Expand All @@ -270,7 +270,7 @@ impl Store {
cas_addresses,
instance_name,
root_ca_certs,
oauth_bearer_token,
headers,
thread_count,
chunk_size_bytes,
upload_timeout,
Expand Down
9 changes: 1 addition & 8 deletions src/rust/engine/fs/store/src/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ByteStore {
cas_addresses: Vec<String>,
instance_name: Option<String>,
root_ca_certs: Option<Vec<u8>>,
oauth_bearer_token: Option<String>,
headers: BTreeMap<String, String>,
_thread_count: usize,
chunk_size_bytes: usize,
upload_timeout: Duration,
Expand All @@ -64,7 +64,6 @@ impl ByteStore {
} else {
"http"
};

let cas_addresses_with_scheme: Vec<_> = cas_addresses
.iter()
.map(|addr| format!("{}://{}", scheme, addr))
Expand All @@ -86,12 +85,6 @@ impl ByteStore {
}

let channel = tonic::transport::Channel::balance_list(endpoints.iter().cloned());

let headers = oauth_bearer_token
.iter()
.map(|t| ("authorization".to_owned(), format!("Bearer {}", t.trim())))
.collect::<BTreeMap<_, _>>();

let interceptor = if headers.is_empty() {
None
} else {
Expand Down
16 changes: 10 additions & 6 deletions src/rust/engine/process_executor/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,19 +218,23 @@ async fn main() {
None
};

let oauth_bearer_token = if let Some(ref path) = args.cas_oauth_bearer_token_path {
Some(std::fs::read_to_string(path).expect("Error reading oauth bearer token file"))
} else {
None
};
let mut headers = BTreeMap::new();
if let Some(ref oauth_path) = args.cas_oauth_bearer_token_path {
let token =
std::fs::read_to_string(oauth_path).expect("Error reading oauth bearer token file");
headers.insert(
"authorization".to_owned(),
format!("Bearer {}", token.trim()),
);
}

Store::with_remote(
executor.clone(),
local_store_path,
vec![cas_server.clone()],
args.remote_instance_name.clone(),
root_ca_certs,
oauth_bearer_token,
headers,
1,
args.upload_chunk_bytes,
Duration::from_secs(30),
Expand Down
Loading

0 comments on commit 2a5346e

Please sign in to comment.