Skip to content

Commit

Permalink
chore: upgrade mypy and add type guards (apache#16227)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro authored Aug 14, 2021
1 parent 9b2dffe commit d46dc9a
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ repos:
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.790
rev: v0.910
hooks:
- id: mypy
additional_dependencies: [types-all]
- repo: https://github.com/peterdemin/pip-compile-multi
rev: v2.4.1
hooks:
Expand Down
4 changes: 2 additions & 2 deletions RELEASING/changelog.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,12 @@ def change_log(
with open(csv, "w") as csv_file:
log_items = list(logs)
field_names = log_items[0].keys()
writer = lib_csv.DictWriter(
writer = lib_csv.DictWriter( # type: ignore
csv_file,
delimiter=",",
quotechar='"',
quoting=lib_csv.QUOTE_ALL,
fieldnames=field_names,
fieldnames=field_names, # type: ignore
)
writer.writeheader()
for log in logs:
Expand Down
10 changes: 7 additions & 3 deletions scripts/benchmark_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@ def import_migration_script(filepath: Path) -> ModuleType:
Import migration script as if it were a module.
"""
spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
if spec:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
raise Exception(
"No module spec found in location: `{path}`".format(path=str(filepath))
)


def extract_modified_tables(module: ModuleType) -> Set[str]:
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ def get_git_sha() -> str:
"simplejson>=3.15.0",
"slackclient==2.5.0", # PINNED! slack changes file upload api in the future versions
"sqlalchemy>=1.3.16, <1.4, !=1.3.21",
"sqlalchemy-utils>=0.36.6,<0.37",
"sqlalchemy-utils>=0.36.6, <0.37",
"sqlparse==0.3.0", # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562
"tabulate==0.8.9",
"typing-extensions>=3.7.4.3,<4", # needed to support typing.Literal on py37
"typing-extensions>=3.10, <4", # needed to support Literal (3.8) and TypeGuard (3.10)
"wtforms-json",
],
extras_require={
"athena": ["pyathena>=1.10.8,<1.11"],
"athena": ["pyathena>=1.10.8, <1.11"],
"bigquery": [
"pandas_gbq>=0.10.0",
"pybigquery>=0.4.10",
Expand Down
2 changes: 1 addition & 1 deletion superset/tasks/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
WEBDRIVER_BASEURL_USER_FRIENDLY = config["WEBDRIVER_BASEURL_USER_FRIENDLY"]

ReportContent = namedtuple(
"EmailContent",
"ReportContent",
[
"body", # email body
"data", # attachments
Expand Down
4 changes: 2 additions & 2 deletions superset/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from flask import Flask
from flask_caching import Cache
from typing_extensions import TypedDict
from typing_extensions import Literal, TypedDict
from werkzeug.wrappers import Response

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,7 +57,7 @@ class AdhocMetricColumn(TypedDict, total=False):
class AdhocMetric(TypedDict, total=False):
aggregate: str
column: Optional[AdhocMetricColumn]
expressionType: str
expressionType: Literal["SIMPLE", "SQL"]
label: Optional[str]
sqlExpression: Optional[str]

Expand Down
4 changes: 2 additions & 2 deletions superset/utils/async_query_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class AsyncQueryManager:

def __init__(self) -> None:
super().__init__()
self._redis: redis.Redis
self._redis: redis.Redis # type: ignore
self._stream_prefix: str = ""
self._stream_limit: Optional[int]
self._stream_limit_firehose: Optional[int]
Expand All @@ -100,7 +100,7 @@ def init_app(self, app: Flask) -> None:
"Please provide a JWT secret at least 32 bytes long"
)

self._redis = redis.Redis( # type: ignore
self._redis = redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
Expand Down
7 changes: 3 additions & 4 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
from typing_extensions import TypedDict
from typing_extensions import TypedDict, TypeGuard

import _thread # pylint: disable=C0411
from superset.constants import (
Expand Down Expand Up @@ -1275,7 +1275,7 @@ def backend() -> str:
return get_example_database().backend


def is_adhoc_metric(metric: Metric) -> bool:
def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]:
return isinstance(metric, dict) and "expressionType" in metric


Expand All @@ -1288,7 +1288,6 @@ def get_metric_name(metric: Metric) -> str:
:raises ValueError: if metric object is invalid
"""
if is_adhoc_metric(metric):
metric = cast(AdhocMetric, metric)
label = metric.get("label")
if label:
return label
Expand All @@ -1306,7 +1305,7 @@ def get_metric_name(metric: Metric) -> str:
if column_name:
return column_name
raise ValueError(__("Invalid metric object"))
return cast(str, metric)
return metric # type: ignore


def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
Expand Down
7 changes: 7 additions & 0 deletions tests/unit_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
GenericDataType,
get_metric_name,
get_metric_names,
is_adhoc_metric,
)

STR_METRIC = "my_metric"
Expand Down Expand Up @@ -91,3 +92,9 @@ def test_get_metric_names():
assert get_metric_names(
[STR_METRIC, SIMPLE_SUM_ADHOC_METRIC, SQL_ADHOC_METRIC]
) == ["my_metric", "my SUM", "my_sql"]


def test_is_adhoc_metric():
assert is_adhoc_metric(STR_METRIC) is False
assert is_adhoc_metric(SIMPLE_SUM_ADHOC_METRIC) is True
assert is_adhoc_metric(SQL_ADHOC_METRIC) is True

0 comments on commit d46dc9a

Please sign in to comment.