Skip to content

Commit

Permalink
Split logic into DynamicRemoteOptions.from_option into different code…
Browse files Browse the repository at this point in the history
… paths (pantsbuild#16172)

I am planning to [add entry point support for the auth plugin](pantsbuild#16212) so the user doesn't have to specify the plugin path. This will require some changes to the logic here and having a function that deals with loading the plugin (not mixed with the other code paths) will make it simpler and easier to grok.

This duplicated some code (extracting options and passing them DynamicRemoteOptions) but I think having 3 code paths for the different use cases justifies that and makes this code more readable and hence more maintainable over time.

[ci skip-build-wheels]
[ci skip-rust]
  • Loading branch information
asherf authored Jul 22, 2022
1 parent 3a9083b commit d089127
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 72 deletions.
218 changes: 147 additions & 71 deletions src/python/pants/option/global_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,51 @@ def disabled(cls) -> DynamicRemoteOptions:
execution_rpc_concurrency=DEFAULT_EXECUTION_OPTIONS.remote_execution_rpc_concurrency,
)

@classmethod
def _use_oauth_token(cls, bootstrap_options: OptionValueContainer) -> DynamicRemoteOptions:
oauth_token = (
Path(bootstrap_options.remote_oauth_bearer_token_path).resolve().read_text().strip()
)
if set(oauth_token).intersection({"\n", "\r"}):
raise OptionsError(
f"OAuth bearer token path {bootstrap_options.remote_oauth_bearer_token_path} "
"must not contain multiple lines."
)
if bootstrap_options.remote_auth_plugin:
raise OptionsError(
"OAuth bearer token path can not be used when setting the `[GLOBAL].remote_auth_plugin` option"
)

token_header = {"authorization": f"Bearer {oauth_token}"}
execution = cast(bool, bootstrap_options.remote_execution)
cache_read = cast(bool, bootstrap_options.remote_cache_read)
cache_write = cast(bool, bootstrap_options.remote_cache_write)
store_address = cast("str | None", bootstrap_options.remote_store_address)
execution_address = cast("str | None", bootstrap_options.remote_execution_address)
instance_name = cast("str | None", bootstrap_options.remote_instance_name)
execution_headers = cast("dict[str, str]", bootstrap_options.remote_execution_headers)
store_headers = cast("dict[str, str]", bootstrap_options.remote_store_headers)
parallelism = cast(int, bootstrap_options.process_execution_remote_parallelism)
store_rpc_concurrency = cast(int, bootstrap_options.remote_store_rpc_concurrency)
cache_rpc_concurrency = cast(int, bootstrap_options.remote_cache_rpc_concurrency)
execution_rpc_concurrency = cast(int, bootstrap_options.remote_execution_rpc_concurrency)
execution_headers.update(token_header)
store_headers.update(token_header)
return cls(
execution=execution,
cache_read=cache_read,
cache_write=cache_write,
instance_name=instance_name,
store_address=cls._normalize_address(store_address),
execution_address=cls._normalize_address(execution_address),
store_headers=store_headers,
execution_headers=execution_headers,
parallelism=parallelism,
store_rpc_concurrency=store_rpc_concurrency,
cache_rpc_concurrency=cache_rpc_concurrency,
execution_rpc_concurrency=execution_rpc_concurrency,
)

@classmethod
def from_options(
cls,
Expand All @@ -239,10 +284,29 @@ def from_options(
execution = cast(bool, bootstrap_options.remote_execution)
cache_read = cast(bool, bootstrap_options.remote_cache_read)
cache_write = cast(bool, bootstrap_options.remote_cache_write)

if not (execution or cache_read or cache_write):
return cls.disabled(), None
if (
bootstrap_options.remote_auth_plugin
and bootstrap_options.remote_oauth_bearer_token_path
):
raise OptionsError(
"Both `[GLOBAL].remote_auth_plugin` and `[GLOBAL].remote_auth_plugin` `[GLOBAL].remote_oauth_bearer_token_path` are set. This is not supported. Only one of those should be set in order to provide auth information for remote cache."
)
if bootstrap_options.remote_oauth_bearer_token_path:
return cls._use_oauth_token(bootstrap_options), None

if bootstrap_options.remote_auth_plugin:
return cls._use_auth_plugin(
bootstrap_options, full_options=full_options, env=env, prior_result=prior_result
)
return cls._use_no_auth(bootstrap_options), None

@classmethod
def _use_no_auth(cls, bootstrap_options: OptionValueContainer) -> DynamicRemoteOptions:
execution = cast(bool, bootstrap_options.remote_execution)
cache_read = cast(bool, bootstrap_options.remote_cache_read)
cache_write = cast(bool, bootstrap_options.remote_cache_write)
store_address = cast("str | None", bootstrap_options.remote_store_address)
execution_address = cast("str | None", bootstrap_options.remote_execution_address)
instance_name = cast("str | None", bootstrap_options.remote_instance_name)
Expand All @@ -252,82 +316,94 @@ def from_options(
store_rpc_concurrency = cast(int, bootstrap_options.remote_store_rpc_concurrency)
cache_rpc_concurrency = cast(int, bootstrap_options.remote_cache_rpc_concurrency)
execution_rpc_concurrency = cast(int, bootstrap_options.remote_execution_rpc_concurrency)
return cls(
execution=execution,
cache_read=cache_read,
cache_write=cache_write,
instance_name=instance_name,
store_address=cls._normalize_address(store_address),
execution_address=cls._normalize_address(execution_address),
store_headers=store_headers,
execution_headers=execution_headers,
parallelism=parallelism,
store_rpc_concurrency=store_rpc_concurrency,
cache_rpc_concurrency=cache_rpc_concurrency,
execution_rpc_concurrency=execution_rpc_concurrency,
)

if bootstrap_options.remote_oauth_bearer_token_path:
oauth_token = (
Path(bootstrap_options.remote_oauth_bearer_token_path).resolve().read_text().strip()
@classmethod
def _use_auth_plugin(
cls,
bootstrap_options: OptionValueContainer,
full_options: Options,
env: CompleteEnvironment,
prior_result: AuthPluginResult | None,
) -> tuple[DynamicRemoteOptions, AuthPluginResult | None]:
auth_plugin_result: AuthPluginResult | None = None
if ":" not in bootstrap_options.remote_auth_plugin:
raise OptionsError(
"Invalid value for `[GLOBAL].remote_auth_plugin`: "
f"{bootstrap_options.remote_auth_plugin}. Please use the format "
"`path.to.module:my_func`."
)
if set(oauth_token).intersection({"\n", "\r"}):
raise OptionsError(
f"OAuth bearer token path {bootstrap_options.remote_oauth_bearer_token_path} "
"must not contain multiple lines."
)
token_header = {"authorization": f"Bearer {oauth_token}"}
execution_headers.update(token_header)
store_headers.update(token_header)
execution = cast(bool, bootstrap_options.remote_execution)
cache_read = cast(bool, bootstrap_options.remote_cache_read)
cache_write = cast(bool, bootstrap_options.remote_cache_write)
store_address = cast("str | None", bootstrap_options.remote_store_address)
execution_address = cast("str | None", bootstrap_options.remote_execution_address)
instance_name = cast("str | None", bootstrap_options.remote_instance_name)
execution_headers = cast("dict[str, str]", bootstrap_options.remote_execution_headers)
store_headers = cast("dict[str, str]", bootstrap_options.remote_store_headers)
parallelism = cast(int, bootstrap_options.process_execution_remote_parallelism)
store_rpc_concurrency = cast(int, bootstrap_options.remote_store_rpc_concurrency)
cache_rpc_concurrency = cast(int, bootstrap_options.remote_cache_rpc_concurrency)
execution_rpc_concurrency = cast(int, bootstrap_options.remote_execution_rpc_concurrency)
auth_plugin_path, _, auth_plugin_func = bootstrap_options.remote_auth_plugin.partition(":")
auth_plugin_module = importlib.import_module(auth_plugin_path)
auth_plugin_func = getattr(auth_plugin_module, auth_plugin_func)
auth_plugin_result = cast(
AuthPluginResult,
auth_plugin_func(
initial_execution_headers=execution_headers,
initial_store_headers=store_headers,
options=full_options,
env=dict(env),
prior_result=prior_result,
),
)
plugin_name = auth_plugin_result.plugin_name or bootstrap_options.remote_auth_plugin
if not auth_plugin_result.is_available:
# NB: This is debug because we expect plugins to log more informative messages.
logger.debug(
f"Disabling remote caching and remote execution because authentication was not available via the plugin {plugin_name} (from `[GLOBAL].remote_auth_plugin`)."
)
return cls.disabled(), None

auth_plugin_result: AuthPluginResult | None = None
if bootstrap_options.remote_auth_plugin:
if ":" not in bootstrap_options.remote_auth_plugin:
raise OptionsError(
"Invalid value for `[GLOBAL].remote_auth_plugin`: "
f"{bootstrap_options.remote_auth_plugin}. Please use the format "
f"`path.to.module:my_func`."
logger.debug(
f"`[GLOBAL].remote_auth_plugin` {plugin_name} succeeded. Remote caching/execution will be attempted."
)
execution_headers = auth_plugin_result.execution_headers
store_headers = auth_plugin_result.store_headers
plugin_provided_opt_log = "Setting `[GLOBAL].remote_{opt}` is not needed and will be ignored since it is provided by the auth plugin: {plugin_name}."
if auth_plugin_result.instance_name is not None:
if instance_name is not None:
logger.warning(
plugin_provided_opt_log.format(opt="instance_name", plugin_name=plugin_name)
)
auth_plugin_path, auth_plugin_func = bootstrap_options.remote_auth_plugin.split(":")
auth_plugin_module = importlib.import_module(auth_plugin_path)
auth_plugin_func = getattr(auth_plugin_module, auth_plugin_func)
auth_plugin_result = cast(
AuthPluginResult,
auth_plugin_func(
initial_execution_headers=execution_headers,
initial_store_headers=store_headers,
options=full_options,
env=dict(env),
prior_result=prior_result,
),
)
plugin_name = auth_plugin_result.plugin_name or bootstrap_options.remote_auth_plugin
if not auth_plugin_result.is_available:
# NB: This is debug because we expect plugins to log more informative messages.
logger.debug(
"Disabling remote caching and remote execution because authentication was not "
f"available via the plugin {plugin_name} (from `[GLOBAL].remote_auth_plugin`)."
instance_name = auth_plugin_result.instance_name
if auth_plugin_result.store_address is not None:
if store_address is not None:
logger.warning(
plugin_provided_opt_log.format(opt="store_address", plugin_name=plugin_name)
)
execution = False
cache_read = False
cache_write = False
else:
logger.debug(
f"`[GLOBAL].remote_auth_plugin` {plugin_name} succeeded. Remote caching/execution will be attempted."
store_address = auth_plugin_result.store_address
if auth_plugin_result.execution_address is not None:
if execution_address is not None:
logger.warning(
plugin_provided_opt_log.format(opt="execution_address", plugin_name=plugin_name)
)
execution_headers = auth_plugin_result.execution_headers
store_headers = auth_plugin_result.store_headers
plugin_provided_opt_log = "Setting `[GLOBAL].remote_{opt}` is not needed and will be ignored since it is provided by the auth plugin: {plugin_name}."
if auth_plugin_result.instance_name is not None:
if instance_name is not None:
logger.warning(
plugin_provided_opt_log.format(
opt="instance_name", plugin_name=plugin_name
)
)
instance_name = auth_plugin_result.instance_name
if auth_plugin_result.store_address is not None:
if store_address is not None:
logger.warning(
plugin_provided_opt_log.format(
opt="store_address", plugin_name=plugin_name
)
)
store_address = auth_plugin_result.store_address
if auth_plugin_result.execution_address is not None:
if execution_address is not None:
logger.warning(
plugin_provided_opt_log.format(
opt="execution_address", plugin_name=plugin_name
)
)
execution_address = auth_plugin_result.execution_address
execution_address = auth_plugin_result.execution_address

opts = cls(
execution=execution,
cache_read=cache_read,
Expand Down
6 changes: 5 additions & 1 deletion src/python/pants/option/global_options_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,18 @@ def auth_func(initial_execution_headers, initial_store_headers, options, **kwarg
}
assert opts.execution_headers == {"exec": "xyz", "foo": "baz"}
assert opts.cache_read is True
assert opts.cache_write is False
assert opts.execution is False
assert opts.instance_name == "custom_instance"
# Note that the grpc:// prefix will be converted to http://.
assert opts.store_address == "http://custom_store"
assert opts.execution_address == "http://custom_exec"

opts = compute_options("UNAVAILABLE", tmp_path)
assert opts.cache_read is False
assert opts.instance_name == "main"
assert opts.cache_write is False
assert opts.execution is False
assert opts.instance_name is None


def test_execution_options_remote_addresses() -> None:
Expand Down

0 comments on commit d089127

Please sign in to comment.