Skip to content

Commit

Permalink
Make AWSCredentialProvider injectable for GlueClient
Browse files Browse the repository at this point in the history
  • Loading branch information
Praveen2112 committed Mar 24, 2022
1 parent 33056ce commit f7e0a0d
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 50 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.hive.metastore.glue;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;

import javax.inject.Inject;
import javax.inject.Provider;

import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public class GlueCredentialsProvider
implements Provider<AWSCredentialsProvider>
{
private final AWSCredentialsProvider credentialsProvider;

@Inject
public GlueCredentialsProvider(GlueHiveMetastoreConfig config)
{
requireNonNull(config, "config is null");
if (config.getAwsCredentialsProvider().isPresent()) {
this.credentialsProvider = getCustomAWSCredentialsProvider(config.getAwsCredentialsProvider().get());
}
else {
AWSCredentialsProvider provider;
if (config.getAwsAccessKey().isPresent() && config.getAwsSecretKey().isPresent()) {
provider = new AWSStaticCredentialsProvider(
new BasicAWSCredentials(config.getAwsAccessKey().get(), config.getAwsSecretKey().get()));
}
else {
provider = DefaultAWSCredentialsProviderChain.getInstance();
}
if (config.getIamRole().isPresent()) {
provider = new STSAssumeRoleSessionCredentialsProvider
.Builder(config.getIamRole().get(), "trino-session")
.withExternalId(config.getExternalId().orElse(null))
.withLongLivedCredentialsProvider(provider)
.build();
}
this.credentialsProvider = provider;
}
}

@Override
public AWSCredentialsProvider get()
{
return credentialsProvider;
}

private static AWSCredentialsProvider getCustomAWSCredentialsProvider(String providerClass)
{
try {
Object instance = Class.forName(providerClass).getConstructor().newInstance();
if (!(instance instanceof AWSCredentialsProvider)) {
throw new RuntimeException("Invalid credentials provider class: " + instance.getClass().getName());
}
return (AWSCredentialsProvider) instance;
}
catch (ReflectiveOperationException e) {
throw new RuntimeException(format("Error creating an instance of %s", providerClass), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.handlers.AsyncHandler;
import com.amazonaws.handlers.RequestHandler2;
Expand Down Expand Up @@ -145,7 +141,6 @@
import static io.trino.spi.StandardErrorCode.ALREADY_EXISTS;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.security.PrincipalType.USER;
import static java.lang.String.format;
import static java.util.Comparator.comparing;
import static java.util.Objects.requireNonNull;
import static java.util.function.UnaryOperator.identity;
Expand Down Expand Up @@ -184,15 +179,17 @@ public class GlueHiveMetastore
public GlueHiveMetastore(
HdfsEnvironment hdfsEnvironment,
GlueHiveMetastoreConfig glueConfig,
AWSCredentialsProvider credentialsProvider,
@ForGlueHiveMetastore Executor partitionsReadExecutor,
GlueColumnStatisticsProviderFactory columnStatisticsProviderFactory,
@ForGlueHiveMetastore Optional<RequestHandler2> requestHandler,
@ForGlueHiveMetastore Predicate<com.amazonaws.services.glue.model.Table> tableFilter)
{
requireNonNull(glueConfig, "glueConfig is null");
requireNonNull(credentialsProvider, "credentialsProvider is null");
this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null");
this.hdfsContext = new HdfsContext(ConnectorIdentity.ofUser(DEFAULT_METASTORE_USER));
this.glueClient = createAsyncGlueClient(glueConfig, requestHandler, stats.newRequestMetricsCollector());
this.glueClient = createAsyncGlueClient(glueConfig, credentialsProvider, requestHandler, stats.newRequestMetricsCollector());
this.defaultDir = glueConfig.getDefaultWarehouseDir();
this.catalogId = glueConfig.getCatalogId().orElse(null);
this.partitionSegments = glueConfig.getPartitionSegments();
Expand All @@ -202,7 +199,7 @@ public GlueHiveMetastore(
this.columnStatisticsProvider = columnStatisticsProviderFactory.createGlueColumnStatisticsProvider(glueClient, stats);
}

public static AWSGlueAsync createAsyncGlueClient(GlueHiveMetastoreConfig config, Optional<RequestHandler2> requestHandler, RequestMetricCollector metricsCollector)
public static AWSGlueAsync createAsyncGlueClient(GlueHiveMetastoreConfig config, AWSCredentialsProvider credentialsProvider, Optional<RequestHandler2> requestHandler, RequestMetricCollector metricsCollector)
{
ClientConfiguration clientConfig = new ClientConfiguration()
.withMaxConnections(config.getMaxGlueConnections())
Expand All @@ -226,48 +223,11 @@ else if (config.getPinGlueClientToCurrentRegion()) {
asyncGlueClientBuilder.setRegion(getCurrentRegionFromEC2Metadata().getName());
}

asyncGlueClientBuilder.setCredentials(getAwsCredentialsProvider(config));
asyncGlueClientBuilder.setCredentials(credentialsProvider);

return asyncGlueClientBuilder.build();
}

private static AWSCredentialsProvider getAwsCredentialsProvider(GlueHiveMetastoreConfig config)
{
if (config.getAwsCredentialsProvider().isPresent()) {
return getCustomAWSCredentialsProvider(config.getAwsCredentialsProvider().get());
}
AWSCredentialsProvider provider;
if (config.getAwsAccessKey().isPresent() && config.getAwsSecretKey().isPresent()) {
provider = new AWSStaticCredentialsProvider(
new BasicAWSCredentials(config.getAwsAccessKey().get(), config.getAwsSecretKey().get()));
}
else {
provider = DefaultAWSCredentialsProviderChain.getInstance();
}
if (config.getIamRole().isPresent()) {
provider = new STSAssumeRoleSessionCredentialsProvider
.Builder(config.getIamRole().get(), "trino-session")
.withExternalId(config.getExternalId().orElse(null))
.withLongLivedCredentialsProvider(provider)
.build();
}
return provider;
}

private static AWSCredentialsProvider getCustomAWSCredentialsProvider(String providerClass)
{
try {
Object instance = Class.forName(providerClass).getConstructor().newInstance();
if (!(instance instanceof AWSCredentialsProvider)) {
throw new RuntimeException("Invalid credentials provider class: " + instance.getClass().getName());
}
return (AWSCredentialsProvider) instance;
}
catch (ReflectiveOperationException e) {
throw new RuntimeException(format("Error creating an instance of %s", providerClass), e);
}
}

public GlueMetastoreStats getStats()
{
return stats;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.hive.metastore.glue;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.handlers.RequestHandler2;
import com.amazonaws.services.glue.model.Table;
import com.google.inject.Binder;
Expand Down Expand Up @@ -48,6 +49,7 @@ protected void setup(Binder binder)
{
configBinder(binder).bindConfig(GlueHiveMetastoreConfig.class);
configBinder(binder).bindConfig(HiveConfig.class);
binder.bind(AWSCredentialsProvider.class).toProvider(GlueCredentialsProvider.class).in(Scopes.SINGLETON);
newOptionalBinder(binder, Key.get(RequestHandler2.class, ForGlueHiveMetastore.class));

newOptionalBinder(binder, Key.get(new TypeLiteral<Predicate<Table>>() {}, ForGlueHiveMetastore.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.hive.metastore.glue;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.glue.AWSGlueAsync;
import com.amazonaws.services.glue.AWSGlueAsyncClientBuilder;
import com.amazonaws.services.glue.model.CreateTableRequest;
Expand Down Expand Up @@ -217,6 +218,7 @@ protected HiveMetastore createMetastore(File tempDir, HiveIdentity identity)
return new GlueHiveMetastore(
HDFS_ENVIRONMENT,
glueConfig,
DefaultAWSCredentialsProviderChain.getInstance(),
executor,
new DefaultGlueColumnStatisticsProviderFactory(glueConfig, executor, executor),
Optional.empty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.iceberg.catalog.glue;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.glue.AWSGlueAsync;
import io.trino.plugin.hive.HdfsEnvironment.HdfsContext;
import io.trino.plugin.hive.metastore.glue.GlueHiveMetastoreConfig;
Expand All @@ -38,12 +39,13 @@ public class GlueIcebergTableOperationsProvider
private final GlueMetastoreStats stats;

@Inject
public GlueIcebergTableOperationsProvider(FileIoProvider fileIoProvider, GlueMetastoreStats stats, GlueHiveMetastoreConfig glueConfig)
public GlueIcebergTableOperationsProvider(FileIoProvider fileIoProvider, GlueMetastoreStats stats, GlueHiveMetastoreConfig glueConfig, AWSCredentialsProvider credentialsProvider)
{
this.fileIoProvider = requireNonNull(fileIoProvider, "fileIoProvider is null");
this.stats = requireNonNull(stats, "stats is null");
requireNonNull(glueConfig, "glueConfig is null");
this.glueClient = createAsyncGlueClient(glueConfig, Optional.empty(), stats.newRequestMetricsCollector());
requireNonNull(credentialsProvider, "credentialsProvider is null");
this.glueClient = createAsyncGlueClient(glueConfig, credentialsProvider, Optional.empty(), stats.newRequestMetricsCollector());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
*/
package io.trino.plugin.iceberg.catalog.glue;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.google.inject.Binder;
import com.google.inject.Scopes;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.hive.metastore.glue.GlueCredentialsProvider;
import io.trino.plugin.hive.metastore.glue.GlueHiveMetastoreConfig;
import io.trino.plugin.hive.metastore.glue.GlueMetastoreStats;
import io.trino.plugin.iceberg.catalog.IcebergTableOperationsProvider;
Expand All @@ -32,6 +34,7 @@ protected void setup(Binder binder)
{
configBinder(binder).bindConfig(GlueHiveMetastoreConfig.class);
binder.bind(GlueMetastoreStats.class).in(Scopes.SINGLETON);
binder.bind(AWSCredentialsProvider.class).toProvider(GlueCredentialsProvider.class).in(Scopes.SINGLETON);
binder.bind(IcebergTableOperationsProvider.class).to(GlueIcebergTableOperationsProvider.class).in(Scopes.SINGLETON);
binder.bind(TrinoCatalogFactory.class).to(TrinoGlueCatalogFactory.class).in(Scopes.SINGLETON);
newExporter(binder).export(TrinoCatalogFactory.class).withGeneratedName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.iceberg.catalog.glue;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.services.glue.AWSGlueAsync;
import io.trino.plugin.hive.HdfsEnvironment;
import io.trino.plugin.hive.metastore.glue.GlueHiveMetastoreConfig;
Expand Down Expand Up @@ -48,6 +49,7 @@ public TrinoGlueCatalogFactory(
HdfsEnvironment hdfsEnvironment,
IcebergTableOperationsProvider tableOperationsProvider,
GlueHiveMetastoreConfig glueConfig,
AWSCredentialsProvider credentialsProvider,
IcebergConfig icebergConfig,
GlueMetastoreStats stats)
{
Expand All @@ -56,7 +58,8 @@ public TrinoGlueCatalogFactory(
requireNonNull(glueConfig, "glueConfig is null");
checkArgument(glueConfig.getCatalogId().isEmpty(), "catalogId configuration is not supported");
this.defaultSchemaLocation = glueConfig.getDefaultWarehouseDir();
this.glueClient = createAsyncGlueClient(glueConfig, Optional.empty(), stats.newRequestMetricsCollector());
requireNonNull(credentialsProvider, "credentialsProvider is null");
this.glueClient = createAsyncGlueClient(glueConfig, credentialsProvider, Optional.empty(), stats.newRequestMetricsCollector());
requireNonNull(icebergConfig, "icebergConfig is null");
this.isUniqueTableLocation = icebergConfig.isUniqueTableLocation();
this.stats = requireNonNull(stats, "stats is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.iceberg;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -103,6 +104,7 @@ protected QueryRunner createQueryRunner()
this.glueMetastore = new GlueHiveMetastore(
hdfsEnvironment,
new GlueHiveMetastoreConfig(),
DefaultAWSCredentialsProviderChain.getInstance(),
directExecutor(),
new DefaultGlueColumnStatisticsProviderFactory(new GlueHiveMetastoreConfig(), directExecutor(), directExecutor()),
Optional.empty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.plugin.iceberg;

import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.glue.AWSGlueAsyncClientBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -59,7 +60,7 @@ protected TrinoCatalog createTrinoCatalog(boolean useUniqueTableLocations)
new NoHdfsAuthentication());
return new TrinoGlueCatalog(
hdfsEnvironment,
new GlueIcebergTableOperationsProvider(new HdfsFileIoProvider(hdfsEnvironment), new GlueMetastoreStats(), new GlueHiveMetastoreConfig()),
new GlueIcebergTableOperationsProvider(new HdfsFileIoProvider(hdfsEnvironment), new GlueMetastoreStats(), new GlueHiveMetastoreConfig(), DefaultAWSCredentialsProviderChain.getInstance()),
AWSGlueAsyncClientBuilder.defaultClient(),
new GlueMetastoreStats(),
Optional.empty(),
Expand All @@ -82,7 +83,7 @@ public void testDefaultLocation()
new NoHdfsAuthentication());
TrinoCatalog catalogWithDefaultLocation = new TrinoGlueCatalog(
hdfsEnvironment,
new GlueIcebergTableOperationsProvider(new HdfsFileIoProvider(hdfsEnvironment), new GlueMetastoreStats(), new GlueHiveMetastoreConfig()),
new GlueIcebergTableOperationsProvider(new HdfsFileIoProvider(hdfsEnvironment), new GlueMetastoreStats(), new GlueHiveMetastoreConfig(), DefaultAWSCredentialsProviderChain.getInstance()),
AWSGlueAsyncClientBuilder.defaultClient(),
new GlueMetastoreStats(),
Optional.of(tmpDirectory.toAbsolutePath().toString()),
Expand Down

0 comments on commit f7e0a0d

Please sign in to comment.