Skip to content

Commit

Permalink
tox -e mypy now works!
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron Loo committed Dec 2, 2020
1 parent b73c473 commit 93df579
Show file tree
Hide file tree
Showing 39 changed files with 281 additions and 144 deletions.
10 changes: 5 additions & 5 deletions detect_secrets/audit/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""
from collections import defaultdict
from typing import Any
from typing import cast
from typing import Dict
from typing import Tuple

from ..core.plugins.util import get_mapping_from_secret_type_to_class
from ..core.potential_secret import PotentialSecret
Expand All @@ -21,7 +21,7 @@ def calculate_statistics_for_baseline(
"""
secrets = get_baseline_from_file(filename)

aggregator = StatisticsAggregator(**kwargs)
aggregator = StatisticsAggregator()
for _, secret in secrets:
# TODO: gather real secrets?
# TODO: do we need repo_info?
Expand All @@ -36,7 +36,7 @@ def __init__(self) -> None:
'stats': StatisticsCounter,
}

self.data = defaultdict(
self.data: Dict[str, Any] = defaultdict(
lambda: {
key: value()
for key, value in framework.items()
Expand All @@ -55,7 +55,7 @@ def record_secret(self, secret: PotentialSecret) -> None:
counter.unknown += 1

def _get_plugin_counter(self, secret_type: str) -> 'StatisticsCounter':
return self.data[secret_type]['stats']
return cast(StatisticsCounter, self.data[secret_type]['stats'])

def __str__(self) -> str:
raise NotImplementedError
Expand All @@ -77,7 +77,7 @@ def __init__(self) -> None:
self.incorrect: int = 0
self.unknown: int = 0

def __repr__(self) -> Tuple[int, int, int]:
def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(correct={self.correct}, '
'incorrect={self.incorrect}, unknown={self.unknown},)'
Expand Down
22 changes: 13 additions & 9 deletions detect_secrets/audit/common.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import json
from functools import lru_cache
from typing import Any
from typing import Dict
from typing import cast
from typing import List

from . import io
from ..core import baseline
from ..core import plugins
from ..core.potential_secret import PotentialSecret
from ..core.secrets_collection import SecretsCollection
from ..exceptions import InvalidBaselineError
from ..exceptions import SecretNotFoundOnSpecifiedLineError
from ..plugins.base import BasePlugin
from ..types import SelfAwareCallable
from ..util.inject import get_injectable_variables
from ..util.inject import inject_variables_into_function


def get_baseline_from_file(filename: str) -> Dict[str, Any]:
def get_baseline_from_file(filename: str) -> SecretsCollection:
"""
:raises: InvalidBaselineError
"""
Expand All @@ -38,29 +40,31 @@ def get_raw_secret_from_file(secret: PotentialSecret) -> str:
:raises: SecretNotFoundOnSpecifiedLineError
"""
plugin = plugins.initialize.from_secret_type(secret.type)
plugin = cast(BasePlugin, plugins.initialize.from_secret_type(secret.type))
try:
target_line = open_file(secret.filename)[secret.line_number - 1]
except IndexError:
raise SecretNotFoundOnSpecifiedLineError(secret.line_number)

function = plugin.__class__.analyze_line
if not hasattr(function, 'injectable_variables'):
function.injectable_variables = set(get_injectable_variables(plugin.analyze_line))
function.path = f'{plugin.__class__.__name__}.analyze_line'
function.injectable_variables = set( # type: ignore
get_injectable_variables(plugin.analyze_line),
)
function.path = f'{plugin.__class__.__name__}.analyze_line' # type: ignore

identified_secrets = inject_variables_into_function(
function,
cast(SelfAwareCallable, function),
self=plugin,
filename=secret.filename,
line=target_line,
line_number=secret.line_number, # TODO: this will be optional
enable_eager_search=True,
)

for identified_secret in identified_secrets:
for identified_secret in (identified_secrets or []):
if identified_secret == secret:
return identified_secret.secret_value
return cast(str, identified_secret.secret_value)

raise SecretNotFoundOnSpecifiedLineError(secret.line_number)

Expand Down
4 changes: 3 additions & 1 deletion detect_secrets/audit/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from typing import Generator
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union

from . import io
from ..core import baseline
Expand Down Expand Up @@ -87,7 +89,7 @@ class RightSecret(Exception):
# This allows us to delay execution of the exception handling, until we had a chance
# to initialize both variables. Either one must at least pass, otherwise the while
# statement will be False.
exception = None
exception: Optional[Union[Type[LeftSecret], Type[RightSecret]]] = None
try:
left_secret = left_secrets[left_index]
except IndexError:
Expand Down
6 changes: 3 additions & 3 deletions detect_secrets/audit/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def print_context(context: SecretContext) -> None:
context.snippet.highlight_line(context.secret.secret_value)
else:
context.snippet.target_line = colorize(context.snippet.target_line, AnsiColor.BOLD)
print_message(context.snippet)
print_message(str(context.snippet))

print_message('-' * 10)

Expand All @@ -48,7 +48,7 @@ def print_secret_not_found(context: SecretContext) -> None:

_print_header(context)

print_message(context.error)
print_message(str(context.error))
print_message('-' * 10)


Expand Down Expand Up @@ -92,7 +92,7 @@ def get_user_decision(
user_input = None
while user_input not in prompter.valid_input:
if user_input:
print('Invalid input.')
print('Invalid input.') # type: ignore # Statement unreachable? Come on mypy...

user_input = input(str(prompter))
if user_input:
Expand Down
5 changes: 3 additions & 2 deletions detect_secrets/core/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import Union

Expand Down Expand Up @@ -46,7 +47,7 @@ def load_from_file(filename: str) -> Dict[str, Any]:
"""
try:
with open(filename) as f:
return json.loads(f.read())
return cast(Dict[str, Any], json.loads(f.read()))
except (FileNotFoundError, IOError, json.decoder.JSONDecodeError) as e:
raise UnableToReadBaselineError from e

Expand Down Expand Up @@ -99,7 +100,7 @@ def upgrade(baseline: Dict[str, Any]) -> Dict[str, Any]:

new_baseline = {**baseline}
for module in modules:
module.upgrade(new_baseline)
module.upgrade(new_baseline) # type: ignore

new_baseline['version'] = VERSION
return new_baseline
Expand Down
35 changes: 25 additions & 10 deletions detect_secrets/core/log.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import logging
import sys
from functools import partial
from typing import cast
from typing import Optional


def get_logger(name=None, format_string=None):
def get_logger(name: Optional[str] = None, format_string: Optional[str] = None) -> 'CustomLogger':
"""
:type name: str
:param name: used for declaring log channels.
:type format_string: str|None
:param format_string: for custom formatting
"""
logging.captureWarnings(True)
log = logging.getLogger(name)

# Bind custom method to instance.
# Source: https://stackoverflow.com/a/2982
log.set_debug_level = _set_debug_level.__get__(log)
log.set_debug_level(0)
log.set_debug_level = partial(CustomLogger.set_debug_level, log) # type: ignore
cast(CustomLogger, log).set_debug_level(0)

# Setting up log formats
log.handlers = []
Expand All @@ -30,13 +30,12 @@ def get_logger(name=None, format_string=None):
)
log.addHandler(handler)

return log
return cast(CustomLogger, log)


def _set_debug_level(self, debug_level):
def _set_debug_level(self: logging.Logger, debug_level: int) -> None:
"""
:type debug_level: int, between 0-2
:param debug_level: configure verbosity of log
:param debug_level: between 0-2, configure verbosity of log
"""
mapping = {
0: logging.ERROR,
Expand All @@ -49,4 +48,20 @@ def _set_debug_level(self, debug_level):
)


class CustomLogger(logging.Logger):
def set_debug_level(self, debug_level: int) -> None:
"""
:param debug_level: between 0-2, configure verbosity of log
"""
mapping = {
0: logging.ERROR,
1: logging.INFO,
2: logging.DEBUG,
}

self.setLevel(
mapping[min(debug_level, 2)],
)


log = get_logger()
16 changes: 10 additions & 6 deletions detect_secrets/core/plugins/initialize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Type

from ...settings import get_settings
from ..log import log
Expand All @@ -19,7 +21,7 @@ def from_secret_type(secret_type: str) -> Plugin:
raise TypeError

try:
return plugin_type(**_get_config(plugin_type.__name__))
return plugin_type(**_get_config(plugin_type.__name__)) # type: ignore
except TypeError:
log.error('Unable to initialize plugin!')
raise
Expand All @@ -42,23 +44,25 @@ def from_plugin_classname(classname: str) -> Plugin:
raise TypeError

try:
return plugin_type(**_get_config(classname))
return plugin_type(**_get_config(classname)) # type: ignore
except TypeError:
log.error('Unable to initialize plugin!')
raise


def from_file(filename: str) -> Iterable[Plugin]:
def from_file(filename: str) -> Iterable[Type[Plugin]]:
"""
:raises: FileNotFoundError
:raises: InvalidFile
"""
output = []
output: List[Type[Plugin]] = []
plugin_class: Type[Plugin]
for plugin_class in get_plugins_from_file(filename):
if plugin_class.secret_type in get_mapping_from_secret_type_to_class():
secret_type = plugin_class.secret_type # type: ignore
if secret_type in get_mapping_from_secret_type_to_class():
log.debug(f'Duplicate plugin detected: {plugin_class.__name__}. Skipping...')

get_mapping_from_secret_type_to_class()[plugin_class.secret_type] = plugin_class
get_mapping_from_secret_type_to_class()[secret_type] = plugin_class
output.append(plugin_class)

return output
Expand Down
6 changes: 4 additions & 2 deletions detect_secrets/core/plugins/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import lru_cache
from types import ModuleType
from typing import Any
from typing import cast
from typing import Dict
from typing import Generator
from typing import Type
Expand Down Expand Up @@ -40,19 +41,20 @@ def get_mapping_from_secret_type_to_class() -> Dict[str, Type[Plugin]]:
# Only supporting file schema right now.
filename = config['path'][len('file://'):]
for plugin_class in get_plugins_from_file(filename):
output[plugin_class.secret_type] = plugin_class
output[cast(BasePlugin, plugin_class).secret_type] = plugin_class

return output


def get_plugins_from_file(filename: str) -> Generator[Type[Plugin], None, None]:
plugin_class: Type[Plugin]
for plugin_class in get_plugins_from_module(import_file_as_module(filename)):
yield plugin_class


def get_plugins_from_module(module: ModuleType) -> Generator[Type[Plugin], None, None]:
for plugin_class in import_types_from_module(module, filter=lambda x: not _is_valid_plugin(x)):
yield plugin_class
yield cast(Type[Plugin], plugin_class)


def _is_valid_plugin(attribute: Any) -> bool:
Expand Down
24 changes: 15 additions & 9 deletions detect_secrets/core/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import subprocess
from functools import lru_cache
from importlib import import_module
from typing import cast
from typing import Generator
from typing import IO
from typing import Iterable
Expand Down Expand Up @@ -31,25 +32,30 @@ def get_files_to_scan(*paths: str, should_scan_all_files: bool) -> Generator[str
valid_paths = git.get_tracked_files(git.get_root_directory())
except subprocess.CalledProcessError:
log.warning('Did not detect git repository. Try scanning all files instead.')
return []
yield from []
return

for path in paths:
iterator = [(os.getcwd(), None, [path])] if os.path.isfile(path) else os.walk(path)
iterator = (
cast(List[Tuple], [(os.getcwd(), None, [path])])
if os.path.isfile(path)
else os.walk(path)
)
for path_root, _, filenames in iterator:
for filename in filenames:
path = get_relative_path_if_in_cwd(os.path.join(path_root, filename))
if not path:
relative_path = get_relative_path_if_in_cwd(os.path.join(path_root, filename))
if not relative_path:
# e.g. symbolic links may be pointing outside the root directory
continue

if (
not should_scan_all_files
and path not in valid_paths
and relative_path not in valid_paths
):
# Not a git-tracked file
continue

yield path
yield relative_path


def scan_line(line: str) -> Generator[PotentialSecret, None, None]:
Expand Down Expand Up @@ -168,12 +174,12 @@ def _scan_for_allowlisted_secrets_in_lines(
get_filters.cache_clear()

line_numbers, lines = zip(*lines)
lines = [line.rstrip() for line in lines]
for line_number, line in zip(line_numbers, lines):
line_content = [line.rstrip() for line in lines]
for line_number, line in zip(line_numbers, line_content):
if not is_line_allowlisted(
filename,
line,
context=get_code_snippet(lines, line_number),
context=get_code_snippet(line_content, line_number),
):
continue

Expand Down
Loading

0 comments on commit 93df579

Please sign in to comment.