Skip to content

Commit

Permalink
Encode azure secrets manager secret names (zenml-io#760)
Browse files Browse the repository at this point in the history
* Encode azure secrets manager secret names

* Remove unraised error from docstring

* Rename name to key
  • Loading branch information
schustmi authored Jul 12, 2022
1 parent 4c438d8 commit aef19bd
Showing 1 changed file with 24 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# permissions and limitations under the License.
"""Implementation of the Azure Secrets Manager integration."""

from typing import Any, ClassVar, Dict, List
import base64
from typing import Any, ClassVar, List

from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
Expand All @@ -29,53 +30,15 @@

ZENML_SCHEMA_NAME = "zenml-schema-name"
ZENML_GROUP_KEY = "zenml-group-key"


def prepend_group_name_to_keys(secret: BaseSecretSchema) -> Dict[str, str]:
"""Adds the secret group name to the keys of each secret key-value pair.
This allows using the same key across multiple
secrets.
Args:
secret: The ZenML Secret schema
Returns:
A dictionary with the secret keys prepended with the group name
"""
return {f"{secret.name}-{k}": v for k, v in secret.content.items()}


def remove_group_name_from_key(combined_key_name: str, group_name: str) -> str:
"""Removes the secret group name from the secret key.
Args:
combined_key_name: Full name as it is within the Azure secrets manager
group_name: Group name (the ZenML Secret name)
Returns:
The cleaned key
Raises:
RuntimeError: If the group name is not found in the key name
"""
if combined_key_name.startswith(f"{group_name}-"):
return combined_key_name[len(f"{group_name}-") :]
else:
raise RuntimeError(
f"Key-name `{combined_key_name}` does not have the "
f"prefix `{group_name}`. Key could not be "
f"extracted."
)
ZENML_KEY_NAME = "zenml-key-name"


class AzureSecretsManager(BaseSecretsManager):
"""Class to interact with the Azure secrets manager.
Attributes:
project_id: This is necessary to access the correct Azure project.
The project_id of your Azure project space that contains
the Secret Manager.
key_vault_name: Name of an Azure Key Vault that this secrets manager
will use to store secrets.
"""

key_vault_name: str
Expand All @@ -100,37 +63,15 @@ def register_secret(self, secret: BaseSecretSchema) -> None:
Raises:
SecretExistsError: if the secret already exists
ValueError: if the secret name contains an underscore.
"""
self._ensure_client_connected(self.key_vault_name)

if "_" in secret.name:
raise ValueError(
f"The secret name `{secret.name}` contains an underscore. "
f"This will cause issues with Azure. Please try again."
)

if secret.name in self.get_all_secret_keys():
raise SecretExistsError(
f"A Secret with the name '{secret.name}' already exists."
)

adjusted_content = prepend_group_name_to_keys(secret)

for k, v in adjusted_content.items():
# Create the secret, this only creates an empty secret with the
# supplied name.
azure_secret = self.CLIENT.set_secret(k, v)
self.CLIENT.update_secret_properties(
azure_secret.name,
tags={
ZENML_GROUP_KEY: secret.name,
ZENML_SCHEMA_NAME: secret.TYPE,
},
)

logger.debug("Created created secret: %s", azure_secret.name)
logger.debug("Added value to secret.")
self.update_secret(secret)

def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Get a secret by its name.
Expand All @@ -151,15 +92,17 @@ def get_secret(self, secret_name: str) -> BaseSecretSchema:
zenml_schema_name = ""

for secret_property in self.CLIENT.list_properties_of_secrets():
response = self.CLIENT.get_secret(secret_property.name)
tags = response.properties.tags
tags = secret_property.tags

if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
secret_key = remove_group_name_from_key(
combined_key_name=response.name, group_name=secret_name
)
secret_key = tags.get(ZENML_KEY_NAME)
if not secret_key:
raise ValueError("Missing secret key tag.")

if secret_key == "name":
raise ValueError("The secret's key cannot be 'name'.")

response = self.CLIENT.get_secret(secret_property.name)
secret_contents[secret_key] = response.value

zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)
Expand All @@ -186,7 +129,7 @@ def get_all_secret_keys(self) -> List[str]:

for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags:
if tags and ZENML_GROUP_KEY in tags:
set_of_secrets.add(tags.get(ZENML_GROUP_KEY))

return list(set_of_secrets)
Expand All @@ -199,18 +142,24 @@ def update_secret(self, secret: BaseSecretSchema) -> None:
"""
self._ensure_client_connected(self.key_vault_name)

adjusted_content = prepend_group_name_to_keys(secret)
for key, value in secret.content.items():
encoded_key = base64.b64encode(
f"{secret.name}-{key}".encode()
).hex()
azure_secret_name = f"zenml-{encoded_key}"

for k, v in adjusted_content.items():
self.CLIENT.set_secret(k, v)
self.CLIENT.set_secret(azure_secret_name, value)
self.CLIENT.update_secret_properties(
k,
azure_secret_name,
tags={
ZENML_GROUP_KEY: secret.name,
ZENML_KEY_NAME: key,
ZENML_SCHEMA_NAME: secret.TYPE,
},
)

logger.debug("Wrote secret: %s", azure_secret_name)

def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret. by name.
Expand Down

0 comments on commit aef19bd

Please sign in to comment.