Skip to content

Commit

Permalink
Add timeout option to gcs hook methods. (apache#13156)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmcarp authored Dec 24, 2020
1 parent b600dfd commit 323084e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
30 changes: 24 additions & 6 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
RT = TypeVar('RT') # pylint: disable=invalid-name
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name

# Default request timeout in seconds; mirrors the default used by google-cloud-storage
DEFAULT_TIMEOUT = 60


def _fallback_object_url_to_object_name_and_bucket_name(
object_url_keyword_arg_name='object_url',
Expand Down Expand Up @@ -257,7 +260,12 @@ def rewrite(
)

def download(
self, object_name: str, bucket_name: Optional[str], filename: Optional[str] = None
self,
object_name: str,
bucket_name: Optional[str],
filename: Optional[str] = None,
chunk_size: Optional[int] = None,
timeout: Optional[int] = DEFAULT_TIMEOUT,
) -> Union[str, bytes]:
"""
Downloads a file from Google Cloud Storage.
Expand All @@ -273,16 +281,20 @@ def download(
:type object_name: str
:param filename: If set, a local file path where the file should be written to.
:type filename: str
:param chunk_size: Blob chunk size in bytes.
:type chunk_size: int
:param timeout: Request timeout in seconds.
:type timeout: int
"""
# TODO: as a future improvement, check the file size before downloading
# to verify that enough local disk space is available

client = self.get_conn()
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)

if filename:
blob.download_to_filename(filename)
blob.download_to_filename(filename, timeout=timeout)
self.log.info('File downloaded to %s', filename)
return filename
else:
Expand Down Expand Up @@ -359,6 +371,8 @@ def upload(
mime_type: Optional[str] = None,
gzip: bool = False,
encoding: str = 'utf-8',
chunk_size: Optional[int] = None,
timeout: Optional[int] = DEFAULT_TIMEOUT,
) -> None:
"""
Uploads a local file or file data as string or bytes to Google Cloud Storage.
Expand All @@ -377,10 +391,14 @@ def upload(
:type gzip: bool
:param encoding: bytes encoding for file data if provided as string
:type encoding: str
:param chunk_size: Blob chunk size in bytes.
:type chunk_size: int
:param timeout: Request timeout in seconds.
:type timeout: int
"""
client = self.get_conn()
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
if filename and data:
raise ValueError(
"'filename' and 'data' parameter provided. Please "
Expand All @@ -398,7 +416,7 @@ def upload(
shutil.copyfileobj(f_in, f_out)
filename = filename_gz

blob.upload_from_filename(filename=filename, content_type=mime_type)
blob.upload_from_filename(filename=filename, content_type=mime_type, timeout=timeout)
if gzip:
os.remove(filename)
self.log.info('File %s uploaded to %s in %s bucket', filename, object_name, bucket_name)
Expand All @@ -412,7 +430,7 @@ def upload(
with gz.GzipFile(fileobj=out, mode="w") as f:
f.write(data)
data = out.getvalue()
blob.upload_from_string(data, content_type=mime_type)
blob.upload_from_string(data, content_type=mime_type, timeout=timeout)
self.log.info('Data stream uploaded to %s in %s bucket', object_name, bucket_name)
else:
raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'google-cloud-secret-manager>=0.2.0,<2.0.0',
'google-cloud-spanner>=1.10.0,<2.0.0',
'google-cloud-speech>=0.36.3,<2.0.0',
'google-cloud-storage>=1.16,<2.0.0',
'google-cloud-storage>=1.30,<2.0.0',
'google-cloud-tasks>=1.2.1,<2.0.0',
'google-cloud-texttospeech>=0.4.0,<2.0.0',
'google-cloud-translate>=1.5.0,<2.0.0',
Expand Down
14 changes: 7 additions & 7 deletions tests/providers/google/cloud/hooks/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def test_download_to_file(self, mock_service):
)

self.assertEqual(response, test_file)
download_filename_method.assert_called_once_with(test_file)
download_filename_method.assert_called_once_with(test_file, timeout=60)

@mock.patch(GCS_STRING.format('NamedTemporaryFile'))
@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
Expand All @@ -697,7 +697,7 @@ def test_provide_file(self, mock_service, mock_temp_file):
with self.gcs_hook.provide_file(bucket_name=test_bucket, object_name=test_object) as response:

self.assertEqual(test_file, response.name)
download_filename_method.assert_called_once_with(test_file)
download_filename_method.assert_called_once_with(test_file, timeout=60)
mock_temp_file.assert_has_calls(
[
mock.call(suffix='test_object'),
Expand Down Expand Up @@ -762,7 +762,7 @@ def test_upload_file(self, mock_service):
self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name)

upload_method.assert_called_once_with(
filename=self.testfile.name, content_type='application/octet-stream'
filename=self.testfile.name, content_type='application/octet-stream', timeout=60
)

@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
Expand All @@ -782,7 +782,7 @@ def test_upload_data_str(self, mock_service):

self.gcs_hook.upload(test_bucket, test_object, data=self.testdata_str)

upload_method.assert_called_once_with(self.testdata_str, content_type='text/plain')
upload_method.assert_called_once_with(self.testdata_str, content_type='text/plain', timeout=60)

@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_upload_data_bytes(self, mock_service):
Expand All @@ -793,7 +793,7 @@ def test_upload_data_bytes(self, mock_service):

self.gcs_hook.upload(test_bucket, test_object, data=self.testdata_bytes)

upload_method.assert_called_once_with(self.testdata_bytes, content_type='text/plain')
upload_method.assert_called_once_with(self.testdata_bytes, content_type='text/plain', timeout=60)

@mock.patch(GCS_STRING.format('BytesIO'))
@mock.patch(GCS_STRING.format('gz.GzipFile'))
Expand All @@ -812,7 +812,7 @@ def test_upload_data_str_gzip(self, mock_service, mock_gzip, mock_bytes_io):
byte_str = bytes(self.testdata_str, encoding)
mock_gzip.assert_called_once_with(fileobj=mock_bytes_io.return_value, mode="w")
gzip_ctx.write.assert_called_once_with(byte_str)
upload_method.assert_called_once_with(data, content_type='text/plain')
upload_method.assert_called_once_with(data, content_type='text/plain', timeout=60)

@mock.patch(GCS_STRING.format('BytesIO'))
@mock.patch(GCS_STRING.format('gz.GzipFile'))
Expand All @@ -829,7 +829,7 @@ def test_upload_data_bytes_gzip(self, mock_service, mock_gzip, mock_bytes_io):

mock_gzip.assert_called_once_with(fileobj=mock_bytes_io.return_value, mode="w")
gzip_ctx.write.assert_called_once_with(self.testdata_bytes)
upload_method.assert_called_once_with(data, content_type='text/plain')
upload_method.assert_called_once_with(data, content_type='text/plain', timeout=60)

@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_upload_exceptions(self, mock_service):
Expand Down

0 comments on commit 323084e

Please sign in to comment.