Skip to content

Commit

Permalink
add retries for az token (with clearing cache) (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
NikitaYurasov authored Aug 20, 2024
1 parent 81d69e2 commit 70e1545
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 5 deletions.
2 changes: 1 addition & 1 deletion dbxio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from dbxio.utils import * # noqa: F403
from dbxio.volume import * # noqa: F403

__version__ = '0.4.0' # single source of truth
__version__ = '0.4.1' # single source of truth
3 changes: 3 additions & 0 deletions dbxio/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class DbxIOClient:

session_configuration: Optional[Dict[str, Any]] = None

def clear_cache(self):
self.credential_provider.clear_cache()

@classmethod
def from_cluster_settings(
cls,
Expand Down
17 changes: 17 additions & 0 deletions dbxio/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def get_credentials(
) -> ClusterCredentials:
raise NotImplementedError

@abstractmethod
def clear_cache(self):
raise NotImplementedError

@classmethod
def from_semi_configured_credentials(
cls,
Expand Down Expand Up @@ -100,6 +104,9 @@ class ClusterEnvAuthProvider(BaseAuthProvider):

_cache: TTLCache = attrs.Factory(lambda: TTLCache(maxsize=1024, ttl=60 * 15))

def clear_cache(self):
self._cache.clear()

@cachedmethod(lambda self: self._cache)
def get_credentials(self) -> ClusterCredentials:
try:
Expand All @@ -125,6 +132,9 @@ class ClusterAirflowAuthProvider(BaseAuthProvider):

_cache: TTLCache = attrs.Factory(lambda: TTLCache(maxsize=1024, ttl=60 * 15))

def clear_cache(self):
self._cache.clear()

@cachedmethod(lambda self: self._cache)
def get_credentials(
self,
Expand Down Expand Up @@ -175,6 +185,10 @@ def __init__(
), 'semi_configured_credentials must be provided if not lazy'
self.ensure_set_auth_provider()

def clear_cache(self):
if self._successful_provider is not None:
self._successful_provider.clear_cache()

def ensure_set_auth_provider(self) -> None:
for provider_type in self._chain:
try:
Expand Down Expand Up @@ -212,6 +226,9 @@ class BareAuthProvider(BaseAuthProvider):

semi_configured_credentials: None = attrs.field(default=None, init=False)

def clear_cache(self):
pass

@cache
def get_credentials(self, **kwargs) -> ClusterCredentials:
return ClusterCredentials(
Expand Down
20 changes: 19 additions & 1 deletion dbxio/delta/table_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,22 @@
from dbxio.blobs.block_upload import upload_file
from dbxio.blobs.parquet import create_pa_table, create_tmp_parquet, pa_table2parquet
from dbxio.core.cloud.client.object_storage import ObjectStorageClient
from dbxio.core.exceptions import ReadDataError
from dbxio.delta.parsers import infer_schema
from dbxio.delta.table import Table, TableFormat
from dbxio.sql.query import ConstDatabricksQuery
from dbxio.sql.results import _FutureBaseResult
from dbxio.utils.blobs import blobs_registries
from dbxio.utils.logging import get_logger
from dbxio.utils.retries import dbxio_retry

if TYPE_CHECKING:
from dbxio.core import DbxIOClient

logger = get_logger()


@dbxio_retry
def exists_table(table: Union[str, Table], client: 'DbxIOClient') -> bool:
"""
Checks if table exists in the catalog. Tries to read one record from the table.
Expand All @@ -32,10 +35,11 @@ def exists_table(table: Union[str, Table], client: 'DbxIOClient') -> bool:
try:
next(read_table(table, limit_records=1, client=client))
return True
except ServerOperationError:
except (ServerOperationError, ReadDataError):
return False


@dbxio_retry
def create_table(table: Union[str, Table], client: 'DbxIOClient') -> _FutureBaseResult:
"""
Creates a table in the catalog.
Expand All @@ -57,6 +61,7 @@ def create_table(table: Union[str, Table], client: 'DbxIOClient') -> _FutureBase
return client.sql(query)


@dbxio_retry
def drop_table(table: Union[str, Table], client: 'DbxIOClient', force: bool = False) -> _FutureBaseResult:
"""
Drops a table from the catalog.
Expand All @@ -70,6 +75,7 @@ def drop_table(table: Union[str, Table], client: 'DbxIOClient', force: bool = Fa
return client.sql(drop_sql)


@dbxio_retry
def read_table(
table: Union[str, Table],
client: 'DbxIOClient',
Expand Down Expand Up @@ -102,6 +108,7 @@ def read_table(
yield record


@dbxio_retry
def save_table_to_files(
table: Union[str, Table],
client: 'DbxIOClient',
Expand All @@ -121,6 +128,7 @@ def save_table_to_files(
return client.sql_to_files(sql_read_query, results_path=results_path, max_concurrency=max_concurrency)


@dbxio_retry
def write_table(
table: Union[str, Table],
new_records: Union[Iterator[Dict], List[Dict]],
Expand Down Expand Up @@ -172,6 +180,7 @@ def write_table(
return client.sql(_sql_query)


@dbxio_retry
def copy_into_table(
client: 'DbxIOClient',
table: Table,
Expand Down Expand Up @@ -199,6 +208,7 @@ def copy_into_table(
client.sql(sql_copy_into_query).wait()


@dbxio_retry
def bulk_write_table(
table: Union[str, Table],
new_records: Union[Iterator[Dict], List[Dict]],
Expand Down Expand Up @@ -248,6 +258,7 @@ def bulk_write_table(
)


@dbxio_retry
def bulk_write_local_files(
table: Table,
path: str,
Expand Down Expand Up @@ -304,6 +315,7 @@ def bulk_write_local_files(
)


@dbxio_retry
def merge_table(
table: 'Union[str , Table]',
new_records: 'Union[Iterator[Dict] , List[Dict]]',
Expand Down Expand Up @@ -342,6 +354,7 @@ def merge_table(
drop_table(tmp_table, client=client, force=True).wait()


@dbxio_retry
def set_comment_on_table(
table: 'Union[str , Table]',
comment: Union[str, None],
Expand All @@ -361,13 +374,15 @@ def set_comment_on_table(
return client.sql(set_comment_query)


@dbxio_retry
def unset_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> _FutureBaseResult:
"""
Unsets the comment on a table.
"""
return set_comment_on_table(table=table, comment=None, client=client)


@dbxio_retry
def get_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> Union[str, None]:
"""
Returns the comment on a table.
Expand All @@ -390,6 +405,7 @@ def get_comment_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') ->
return None


@dbxio_retry
def set_tags_on_table(table: 'Union[str , Table]', tags: dict[str, str], client: 'DbxIOClient') -> _FutureBaseResult:
"""
Sets tags on a table.
Expand All @@ -406,6 +422,7 @@ def set_tags_on_table(table: 'Union[str , Table]', tags: dict[str, str], client:
return client.sql(set_tags_query)


@dbxio_retry
def unset_tags_on_table(table: 'Union[str , Table]', tags: list[str], client: 'DbxIOClient') -> _FutureBaseResult:
"""
Unsets tags on a table.
Expand All @@ -421,6 +438,7 @@ def unset_tags_on_table(table: 'Union[str , Table]', tags: list[str], client: 'D
return client.sql(unset_tags_query)


@dbxio_retry
def get_tags_on_table(table: 'Union[str , Table]', client: 'DbxIOClient') -> dict[str, str]:
"""
Returns the tags on a table.
Expand Down
2 changes: 2 additions & 0 deletions dbxio/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)
from dbxio.utils.http import get_session
from dbxio.utils.logging import get_logger
from dbxio.utils.retries import dbxio_retry

__all__ = [
'ClusterType',
Expand All @@ -21,4 +22,5 @@
'blobs_gc',
'get_session',
'get_logger',
'dbxio_retry',
]
31 changes: 31 additions & 0 deletions dbxio/utils/retries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from databricks.sdk.errors.platform import PermissionDenied
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential


def _clear_client_cache(call_state: RetryCallState) -> None:
"""
Gets all argument of the function from retry state, finds the client and clears its cache.
"""
if call_state.attempt_number == 1:
# Do not clear cache on the first attempt
return

from dbxio.core.client import DbxIOClient

for arg in call_state.args:
if isinstance(arg, DbxIOClient):
arg.clear_cache()
return
for arg in call_state.kwargs.values():
if isinstance(arg, DbxIOClient):
arg.clear_cache()
return


dbxio_retry = retry(
stop=stop_after_attempt(7),
wait=wait_exponential(multiplier=1),
retry=retry_if_exception_type((PermissionDenied,)),
reraise=True,
before=_clear_client_cache,
)
Loading

0 comments on commit 70e1545

Please sign in to comment.