Skip to content

Commit

Permalink
feat: API endpoint to import charts (apache#11744)
Browse files Browse the repository at this point in the history
* ImportChartsCommand

* feat: API endpoint to import charts

* Add dispatcher

* Fix docstring
  • Loading branch information
betodealmeida authored Nov 20, 2020
1 parent 2f4f877 commit a3a2a68
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 13 deletions.
56 changes: 56 additions & 0 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ChartUpdateFailedError,
)
from superset.charts.commands.export import ExportChartsCommand
from superset.charts.commands.importers.dispatcher import ImportChartsCommand
from superset.charts.commands.update import UpdateChartCommand
from superset.charts.dao import ChartDAO
from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
Expand All @@ -59,6 +60,7 @@
screenshot_query_schema,
thumbnail_query_schema,
)
from superset.commands.exceptions import CommandInvalidError
from superset.constants import RouteMethod
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
Expand Down Expand Up @@ -86,6 +88,7 @@ class ChartRestApi(BaseSupersetModelRestApi):

include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
RouteMethod.EXPORT,
RouteMethod.IMPORT,
RouteMethod.RELATED,
"bulk_delete", # not using RouteMethod since locally defined
"data",
Expand Down Expand Up @@ -823,3 +826,56 @@ def favorite_status(self, **kwargs: Any) -> Response:
for request_id in requested_ids
]
return self.response(200, result=res)

@expose("/import/", methods=["POST"])
@protect()
@safe
@statsd_metrics
def import_(self) -> Response:
"""Import chart(s) with associated datasets and databases
---
post:
requestBody:
content:
application/zip:
schema:
type: string
format: binary
responses:
200:
description: Chart import result
content:
application/json:
schema:
type: object
properties:
message:
type: string
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
upload = request.files.get("file")
if not upload:
return self.response_400()
with ZipFile(upload) as bundle:
contents = {
file_name: bundle.read(file_name).decode()
for file_name in bundle.namelist()
}

command = ImportChartsCommand(contents)
try:
command.run()
return self.response(200, message="OK")
except CommandInvalidError as exc:
logger.warning("Import chart failed")
return self.response_422(message=exc.normalized_messages())
except Exception as exc: # pylint: disable=broad-except
logger.exception("Import chart failed")
return self.response_500(message=str(exc))
70 changes: 70 additions & 0 deletions superset/charts/commands/importers/dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from typing import Any, Dict

from marshmallow.exceptions import ValidationError

from superset.charts.commands.importers import v1
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError

logger = logging.getLogger(__name__)

# Ordered list of importer versions to try; ImportChartsCommand.run()
# walks this list until one version accepts the bundle format.
command_versions = [
    v1.ImportChartsCommand,
]


class ImportChartsCommand(BaseCommand):
    """
    Import charts.

    This command dispatches the import to different versions of the command
    until it finds one that matches.
    """

    # pylint: disable=unused-argument
    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
        # Mapping of file name inside the bundle -> decoded file contents.
        self.contents = contents

    def run(self) -> None:
        # Iterate over all known command versions until we find one that
        # can handle the contents of the bundle.
        for version in command_versions:
            command = version(self.contents)
            try:
                command.run()
                return
            except IncorrectVersionError:
                # file is not handled by this version of the command, skip
                continue
            except (CommandInvalidError, ValidationError):
                # found the right version, but the file is invalid; use a
                # bare raise so the original traceback is preserved intact
                logger.info("Command failed validation")
                raise
            except Exception:
                # validation succeeded but something went wrong running it
                logger.exception("Error running import command")
                raise

        raise CommandInvalidError("Could not find a valid command to import file")

    def validate(self) -> None:
        # Validation is performed by the versioned command inside run().
        pass
94 changes: 89 additions & 5 deletions tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from datetime import datetime
from io import BytesIO
from unittest import mock
from zipfile import is_zipfile
from zipfile import is_zipfile, ZipFile

import humanize
import prison
import pytest
import yaml
from sqlalchemy import and_
from sqlalchemy.sql import func

Expand All @@ -35,12 +36,19 @@
from tests.test_app import app
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import db, security_manager
from superset.models.core import FavStar, FavStarClassName
from superset.models.core import Database, FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils import core as utils
from tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.base_tests import SupersetTestCase
from tests.fixtures.importexport import (
chart_config,
chart_metadata_config,
database_config,
dataset_config,
dataset_metadata_config,
)
from tests.fixtures.query_context import get_query_context

CHART_DATA_URI = "api/v1/chart/data"
Expand Down Expand Up @@ -1131,7 +1139,7 @@ def test_chart_data_jinja_filter_request(self):

def test_export_chart(self):
"""
Chart API: Test export dataset
Chart API: Test export chart
"""
example_chart = db.session.query(Slice).all()[0]
argument = [example_chart.id]
Expand All @@ -1147,7 +1155,7 @@ def test_export_chart(self):

def test_export_chart_not_found(self):
"""
Dataset API: Test export dataset not found
Chart API: Test export chart not found
"""
# Just one does not exist and we get 404
argument = [-1, 1]
Expand All @@ -1159,7 +1167,7 @@ def test_export_chart_not_found(self):

def test_export_chart_gamma(self):
"""
Dataset API: Test export dataset has gamma
Chart API: Test export chart has gamma
"""
example_chart = db.session.query(Slice).all()[0]
argument = [example_chart.id]
Expand All @@ -1169,3 +1177,79 @@ def test_export_chart_gamma(self):
rv = self.client.get(uri)

assert rv.status_code == 404

def test_import_chart(self):
    """
    Chart API: Test import chart
    """
    self.login(username="admin")
    uri = "api/v1/chart/import/"

    # Build an export bundle in memory: metadata plus one database,
    # dataset, and chart definition.
    entries = {
        "metadata.yaml": chart_metadata_config,
        "databases/imported_database.yaml": database_config,
        "datasets/imported_dataset.yaml": dataset_config,
        "charts/imported_chart.yaml": chart_config,
    }
    buf = BytesIO()
    with ZipFile(buf, "w") as bundle:
        for path, config in entries.items():
            with bundle.open(path, "w") as fp:
                fp.write(yaml.safe_dump(config).encode())
    buf.seek(0)

    form_data = {"file": (buf, "chart_export.zip")}
    rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
    response = json.loads(rv.data.decode("utf-8"))

    assert rv.status_code == 200
    assert response == {"message": "OK"}

    # The import should have created the database, its dataset, and the
    # chart, all linked together.
    database = (
        db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
    )
    assert database.database_name == "imported_database"

    assert len(database.tables) == 1
    dataset = database.tables[0]
    assert dataset.table_name == "imported_dataset"
    assert str(dataset.uuid) == dataset_config["uuid"]

    chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one()
    assert chart.table == dataset

    # Clean up the imported objects.
    for imported in (chart, dataset, database):
        db.session.delete(imported)
    db.session.commit()

def test_import_chart_invalid(self):
    """
    Chart API: Test import invalid chart
    """
    self.login(username="admin")
    uri = "api/v1/chart/import/"

    # Deliberately wrong metadata: a dataset export's metadata inside a
    # chart bundle, so validation must reject it.
    entries = {
        "metadata.yaml": dataset_metadata_config,
        "databases/imported_database.yaml": database_config,
        "datasets/imported_dataset.yaml": dataset_config,
        "charts/imported_chart.yaml": chart_config,
    }
    buf = BytesIO()
    with ZipFile(buf, "w") as bundle:
        for path, config in entries.items():
            with bundle.open(path, "w") as fp:
                fp.write(yaml.safe_dump(config).encode())
    buf.seek(0)

    form_data = {"file": (buf, "chart_export.zip")}
    rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
    response = json.loads(rv.data.decode("utf-8"))

    assert rv.status_code == 422
    assert response == {
        "message": {"metadata.yaml": {"type": ["Must be equal to Slice."]}}
    }
4 changes: 2 additions & 2 deletions tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def test_import_database(self):
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)

Expand Down Expand Up @@ -880,7 +880,7 @@ def test_import_database_invalid(self):
fp.write(yaml.safe_dump(dataset_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)

Expand Down
12 changes: 6 additions & 6 deletions tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,7 @@ def test_get_datasets_custom_filter_sql(self):
for table_name in self.fixture_tables_names:
assert table_name in [ds["table_name"] for ds in data["result"]]

def test_import_dataset(self):
def test_imported_dataset(self):
"""
Dataset API: Test import dataset
"""
Expand All @@ -1189,7 +1189,7 @@ def test_import_dataset(self):
fp.write(yaml.safe_dump(dataset_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)

Expand All @@ -1216,7 +1216,7 @@ def test_import_dataset(self):
db.session.delete(database)
db.session.commit()

def test_import_dataset_invalid(self):
def test_imported_dataset_invalid(self):
"""
Dataset API: Test import invalid dataset
"""
Expand All @@ -1229,7 +1229,7 @@ def test_import_dataset_invalid(self):
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)

Expand All @@ -1244,7 +1244,7 @@ def test_import_dataset_invalid(self):
"message": {"metadata.yaml": {"type": ["Must be equal to SqlaTable."]}}
}

def test_import_dataset_invalid_v0_validation(self):
def test_imported_dataset_invalid_v0_validation(self):
"""
Dataset API: Test import invalid dataset
"""
Expand All @@ -1255,7 +1255,7 @@ def test_import_dataset_invalid_v0_validation(self):
with ZipFile(buf, "w") as bundle:
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)

Expand Down

0 comments on commit a3a2a68

Please sign in to comment.