Skip to content

Commit

Permalink
fix(aws provider): Use default region provider if region unspecified (v…
Browse files Browse the repository at this point in the history
…ectordotdev#12475)

* fix(aws provider): Make region _actually_ required

Fixes: vectordotdev#12474

Signed-off-by: Jesse Szwedko <[email protected]>
  • Loading branch information
jszwedko authored May 4, 2022
1 parent fdb643b commit ce5777f
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 33 deletions.
13 changes: 12 additions & 1 deletion src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::config::ProxyConfig;
use crate::http::{build_proxy_connector, build_tls_connector};
use crate::tls::{MaybeTlsSettings, TlsConfig};
pub use auth::AwsAuthentication;
use aws_config::meta::region::ProvideRegion;
use aws_smithy_async::rt::sleep::{AsyncSleep, Sleep};
use aws_smithy_client::erase::DynConnector;
use aws_smithy_client::SdkError;
Expand Down Expand Up @@ -84,11 +85,21 @@ pub trait ClientBuilder {

pub async fn create_client<T: ClientBuilder>(
auth: &AwsAuthentication,
region: Region,
region: Option<Region>,
endpoint: Option<Endpoint>,
proxy: &ProxyConfig,
tls_options: &Option<TlsConfig>,
) -> crate::Result<T::Client> {
// The default credentials chains will look for a region if not given but we'd like to
// error up front if later SDK calls will fail due to lack of region configuration
let region = match region {
Some(region) => Ok(region),
None => aws_config::default_provider::region::default_provider()
.region()
.await
.ok_or("Could not determine region from Vector configuration or default providers"),
}?;

let mut config_builder =
T::create_config_builder(auth.credentials_provider(region.clone()).await?);

Expand Down
40 changes: 35 additions & 5 deletions src/aws/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ use std::str::FromStr;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct RegionOrEndpoint {
pub region: String,
pub region: Option<String>,
pub endpoint: Option<String>,
}

impl RegionOrEndpoint {
pub const fn with_region(region: String) -> Self {
Self {
region,
region: Some(region),
endpoint: None,
}
}

pub fn with_both(region: impl Into<String>, endpoint: impl Into<String>) -> Self {
Self {
region: region.into(),
region: Some(region.into()),
endpoint: Some(endpoint.into()),
}
}
Expand All @@ -34,7 +34,37 @@ impl RegionOrEndpoint {
}
}

pub fn region(&self) -> Region {
Region::new(self.region.clone())
pub fn region(&self) -> Option<Region> {
self.region.clone().map(Region::new)
}
}

#[cfg(test)]
mod tests {
use indoc::indoc;

use super::*;

#[test]
fn optional() {
assert!(toml::from_str::<RegionOrEndpoint>(indoc! {r#"
"#})
.is_ok());
}

#[test]
fn region_optional() {
assert!(toml::from_str::<RegionOrEndpoint>(indoc! {r#"
endpoint = "http://localhost:8080"
"#})
.is_ok());
}

#[test]
fn endpoint_optional() {
assert!(toml::from_str::<RegionOrEndpoint>(indoc! {r#"
region = "us-east-1"
"#})
.is_ok());
}
}
8 changes: 6 additions & 2 deletions src/internal_events/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,23 @@ impl<'a> InternalEvent for EndpointBytesSent<'a> {
#[cfg(feature = "aws-core")]
pub struct AwsBytesSent {
pub byte_size: usize,
pub region: aws_types::region::Region,
pub region: Option<aws_types::region::Region>,
}

#[cfg(feature = "aws-core")]
impl InternalEvent for AwsBytesSent {
fn emit(self) {
let region = self
.region
.as_ref()
.map(|r| r.as_ref().to_string())
.unwrap_or_default();
trace!(
message = "Bytes sent.",
protocol = "https",
byte_size = %self.byte_size,
region = ?self.region,
);
let region = self.region.to_string();
counter!(
"component_sent_bytes_total", self.byte_size as u64,
"protocol" => "https",
Expand Down
2 changes: 1 addition & 1 deletion src/sinks/aws_cloudwatch_logs/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ async fn cloudwatch_healthcheck() {

async fn create_client_test() -> CloudwatchLogsClient {
let auth = AwsAuthentication::test_auth();
let region = Region::new("localstack");
let region = Some(Region::new("localstack"));
let watchlogs_address = watchlogs_address();
let endpoint = Some(Endpoint::immutable(
Uri::from_str(&watchlogs_address).unwrap(),
Expand Down
2 changes: 1 addition & 1 deletion src/sinks/aws_cloudwatch_metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl CloudWatchMetricsSinkConfig {
async fn create_client(&self, proxy: &ProxyConfig) -> crate::Result<CloudwatchClient> {
let region = if cfg!(test) {
// Moto (used for mocking AWS) doesn't recognize 'custom' as valid region name
Region::new("us-east-1")
Some(Region::new("us-east-1"))
} else {
self.region.region()
};
Expand Down
2 changes: 1 addition & 1 deletion src/sinks/aws_kinesis_firehose/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async fn ensure_elasticsearch_domain(domain_name: String) -> String {
aws_sdk_elasticsearch::config::Builder::new()
.credentials_provider(
AwsAuthentication::test_auth()
.credentials_provider(test_region_endpoint().region())
.credentials_provider(test_region_endpoint().region().unwrap())
.await
.unwrap(),
)
Expand Down
2 changes: 1 addition & 1 deletion src/sinks/aws_kinesis_firehose/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use aws_sdk_firehose::{Client as KinesisFirehoseClient, Region};
#[derive(Clone)]
pub struct KinesisService {
pub client: KinesisFirehoseClient,
pub region: Region,
pub region: Option<Region>,
pub stream_name: String,
}

Expand Down
2 changes: 1 addition & 1 deletion src/sinks/aws_kinesis_streams/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use aws_types::region::Region;
pub struct KinesisService {
pub client: KinesisClient,
pub stream_name: String,
pub region: Region,
pub region: Option<Region>,
}

pub struct KinesisResponse {
Expand Down
2 changes: 1 addition & 1 deletion src/sinks/aws_sqs/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async fn create_test_client() -> SqsClient {
let proxy = ProxyConfig::default();
create_client::<SqsClientBuilder>(
&auth,
Region::new("localstack"),
Some(Region::new("localstack")),
Some(Endpoint::immutable(Uri::from_str(&endpoint).unwrap())),
&proxy,
&None,
Expand Down
3 changes: 2 additions & 1 deletion src/sinks/elasticsearch/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ impl ElasticsearchCommon {
.aws
.as_ref()
.map(|config| config.region())
.ok_or(ParseError::RegionRequired)?
.ok_or(ParseError::RegionRequired)?;

Some(aws.credentials_provider(region).await?)
Expand Down Expand Up @@ -120,7 +121,7 @@ impl ElasticsearchCommon {
metric_config.timezone.unwrap_or_default(),
);

let region = config.aws.as_ref().map(|config| config.region());
let region = config.aws.as_ref().and_then(|config| config.region());

Ok(Self {
http_auth,
Expand Down
4 changes: 2 additions & 2 deletions src/sinks/s3_common/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ impl DriverResponse for S3Response {
#[derive(Clone)]
pub struct S3Service {
client: S3Client,
region: Region,
region: Option<Region>,
}

impl S3Service {
pub const fn new(client: S3Client, region: Region) -> S3Service {
pub const fn new(client: S3Client, region: Option<Region>) -> S3Service {
S3Service { client, region }
}

Expand Down
15 changes: 10 additions & 5 deletions src/sources/aws_s3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,10 @@ impl AwsS3Config {
multiline: Option<line_agg::Config>,
proxy: &ProxyConfig,
) -> crate::Result<sqs::Ingestor> {
let region = self.region.region();
let region = self
.region
.region()
.ok_or(CreateSqsIngestorError::RegionMissing)?;

let endpoint = self
.region
Expand All @@ -166,7 +169,7 @@ impl AwsS3Config {

let s3_client = create_client::<S3ClientBuilder>(
&self.auth,
region.clone(),
Some(region.clone()),
endpoint.clone(),
proxy,
&self.tls_options,
Expand All @@ -177,7 +180,7 @@ impl AwsS3Config {
Some(ref sqs) => {
let sqs_client = create_client::<SqsClientBuilder>(
&self.auth,
region.clone(),
Some(region.clone()),
endpoint,
proxy,
&sqs.tls_options,
Expand Down Expand Up @@ -211,6 +214,8 @@ enum CreateSqsIngestorError {
Credentials { source: crate::Error },
#[snafu(display("Configuration for `sqs` required when strategy=sqs"))]
ConfigMissing,
#[snafu(display("Region is required"))]
RegionMissing,
#[snafu(display("Endpoint is invalid"))]
InvalidEndpoint,
}
Expand Down Expand Up @@ -747,7 +752,7 @@ mod integration_tests {
async fn s3_client() -> S3Client {
let auth = AwsAuthentication::test_auth();
let region_endpoint = RegionOrEndpoint {
region: "us-east-1".to_owned(),
region: Some("us-east-1".to_owned()),
endpoint: Some(s3_address()),
};
let proxy_config = ProxyConfig::default();
Expand All @@ -765,7 +770,7 @@ mod integration_tests {
async fn sqs_client() -> SqsClient {
let auth = AwsAuthentication::test_auth();
let region_endpoint = RegionOrEndpoint {
region: "us-east-1".to_owned(),
region: Some("us-east-1".to_owned()),
endpoint: Some(s3_address()),
};
let proxy_config = ProxyConfig::default();
Expand Down
14 changes: 3 additions & 11 deletions tests/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,17 +425,9 @@ async fn bad_s3_region() {
encoding = "text"
bucket = "asdf"
key_prefix = "logs/"
[sinks.out2]
type = "aws_s3"
inputs = ["in"]
compression = "gzip"
encoding = "text"
bucket = "asdf"
key_prefix = "logs/"
region = "moonbase-alpha"
[sinks.out3]
[sinks.out2]
type = "aws_s3"
inputs = ["in"]
compression = "gzip"
Expand All @@ -444,15 +436,15 @@ async fn bad_s3_region() {
key_prefix = "logs/"
endpoint = "this shouldnt work"
[sinks.out3.batch]
[sinks.out2.batch]
max_bytes = 100000
"#,
Format::Toml,
)
.await
.unwrap_err();

assert_eq!(err, vec!["Sink \"out3\": invalid uri character"])
assert_eq!(err, vec!["Sink \"out2\": invalid uri character"])
}

#[cfg(all(
Expand Down

0 comments on commit ce5777f

Please sign in to comment.