-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathathena_reagent
78 lines (63 loc) · 2.5 KB
/
athena_reagent
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import boto3
from retrying import retry
import pprint
import sys
import os
# --- AWS credentials and region, taken from the process environment ---------
# Each value is None when the corresponding variable is not set.
aws_access_key = os.environ.get("AWS_ACCESS_KEY")
aws_secret_access_key = os.environ.get("AWS_SECRET_KEY")
region = os.environ.get("AWS_REGION")

# --- Where Athena writes query results ---------------------------------------
# AthenaManager combines these into the staging location
# 's3://' + bucket_name + '/' + s3_path.
bucket_name = "s3_bucket_name"
s3_path = "athena_results/"

# --- Athena database the queries run against ---------------------------------
DATABASE_NAME = "your_database_name"
class AthenaManager:
    """Run SQL files on Amazon Athena and download their CSV results from S3.

    The manager owns one boto3 session from which an Athena client and an S3
    resource are created. Query output is staged under
    ``s3://<bucket_name>/<s3_path>``.
    """

    # States after which the query status will not change any more, so
    # polling must stop instead of retrying.
    _TERMINAL_STATES = ('SUCCEEDED', 'FAILED', 'CANCELLED')

    def __init__(self, aws_access_key, aws_secret_access_key, region, bucket_name, s3_path, database):
        """Create the AWS clients and remember where results are staged.

        Args:
            aws_access_key: AWS access key id passed to ``boto3.Session``.
            aws_secret_access_key: AWS secret access key passed to ``boto3.Session``.
            region: Region name used for both the Athena client and the S3 resource.
            bucket_name: Name of the S3 bucket that receives query results.
            s3_path: Key prefix inside the bucket; it is concatenated directly
                with the query execution id, so it should end with ``/``.
            database: Athena database used as the query execution context.
        """
        self.session = boto3.Session(
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_access_key,
        )
        self.athena = self.session.client('athena', region)
        self.s3 = self.session.resource('s3', region)
        self.bucket_name = bucket_name
        self.s3_path = s3_path
        self.staging_dir = 's3://' + bucket_name + '/' + s3_path
        self.database = database

    @retry(stop_max_attempt_number=10,
           wait_exponential_multiplier=30 * 1000,
           wait_exponential_max=10 * 60 * 1000)
    def poll_status(self, id):
        """Return the query execution description once the query has finished.

        The method is wrapped by ``retry``: raising from here triggers another
        attempt (at most 10 attempts, exponential wait with a 30 s multiplier
        capped at 10 min).

        Args:
            id: The Athena ``QueryExecutionId`` to look up.

        Returns:
            The ``get_query_execution`` response when the state is
            ``SUCCEEDED``, ``FAILED`` or ``CANCELLED``.

        Raises:
            Exception: If the query is still in a non-terminal state; this is
                what makes the decorator poll again, and it propagates to the
                caller once the attempts are exhausted.
        """
        result = self.athena.get_query_execution(
            QueryExecutionId=id
        )
        state = result['QueryExecution']['Status']['State']
        if state in self._TERMINAL_STATES:
            return result
        # Not finished yet: signal the retry decorator, with a message that is
        # useful if this ends up being the error reported to the user.
        raise Exception(
            'Athena query {} has not finished yet (state: {})'.format(id, state)
        )

    def query_to_athena(self, filename):
        """Execute the SQL stored in ``filename`` and save its outcome locally.

        Writes the final ``get_query_execution`` response to
        ``<filename>.log``. If the query succeeded, also downloads
        ``<s3_path><QueryExecutionId>.csv`` from the bucket to
        ``<filename>.csv``.

        Args:
            filename: Path of a text file containing the SQL to run.
        """
        # Read through a context manager so the file handle is always closed.
        with open(filename, 'r') as sql_file:
            sql = sql_file.read()

        started = self.athena.start_query_execution(
            QueryString=sql,
            QueryExecutionContext={
                'Database': self.database
            },
            ResultConfiguration={
                'OutputLocation': self.staging_dir
            }
        )
        query_execution_id = started['QueryExecutionId']
        result = self.poll_status(query_execution_id)

        # save response
        with open(filename + '.log', 'w') as f:
            f.write(pprint.pformat(result, indent=4))

        # save query result from S3
        if result['QueryExecution']['Status']['State'] == 'SUCCEEDED':
            s3_key = self.s3_path + query_execution_id + '.csv'
            local_filename = filename + '.csv'
            self.s3.Bucket(self.bucket_name).download_file(s3_key, local_filename)
if __name__ == '__main__':
    # Script entry point: every command-line argument is the path of a SQL
    # file to run through Athena.
    athena_mgr = AthenaManager(aws_access_key, aws_secret_access_key, region, bucket_name, s3_path, DATABASE_NAME)
    failed_files = []
    for filename in sys.argv[1:]:
        try:
            athena_mgr.query_to_athena(filename)
        except Exception as err:
            # Best effort: keep processing the remaining files, but say which
            # file raised and report it on stderr rather than stdout.
            sys.stderr.write(
                '{}: {}: {}\n'.format(filename, type(err).__name__, err)
            )
            failed_files.append(filename)
    # Let calling scripts detect that at least one file raised an error.
    if failed_files:
        sys.exit(1)