From 73cd2ea3b17574f8fef1112aa5e5b39f843882f6 Mon Sep 17 00:00:00 2001
From: Bogdan
Date: Tue, 11 Oct 2016 17:54:40 -0700
Subject: [PATCH] Import / export of the dashboards. (#1197)

* Implement import / export dashboard functionality.
* Address comments from discussion.
* Add function descriptions.
* Minor fixes
* Fix tests for python 3.
* Export datasources.
* Implement tables import.
* Json.loads does not support trailing commas.
* Improve alter_dict func
* Resolve comments.
* Refactor tests
* Move params_dict and alter_params to the ImportMixin
* Fix flask menus.
---
 caravel/__init__.py                              |   6 +
 caravel/data/__init__.py                         |   4 +-
 .../b46fa1b0b39e_add_params_to_tables.py         |  28 ++
 caravel/models.py                                | 321 ++++++++++++++-
 caravel/source_registry.py                       |  28 ++
 .../templates/caravel/export_dashboards.html    |   6 +
 .../templates/caravel/import_dashboards.html    |  23 ++
 caravel/views.py                                 |  68 +++-
 tests/core_tests.py                              |   1 -
 tests/import_export_tests.py                     | 368 ++++++++++++++++++
 10 files changed, 827 insertions(+), 26 deletions(-)
 create mode 100644 caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py
 create mode 100644 caravel/templates/caravel/export_dashboards.html
 create mode 100644 caravel/templates/caravel/import_dashboards.html
 create mode 100644 tests/import_export_tests.py

diff --git a/caravel/__init__.py b/caravel/__init__.py
index 1fcfe43acb992..bb74091a714a5 100644
--- a/caravel/__init__.py
+++ b/caravel/__init__.py
@@ -82,6 +82,12 @@ def checkout(dbapi_con, con_record, con_proxy):
 if app.config.get('ENABLE_PROXY_FIX'):
     app.wsgi_app = ProxyFix(app.wsgi_app)
 
+if app.config.get('UPLOAD_FOLDER'):
+    try:
+        os.makedirs(app.config.get('UPLOAD_FOLDER'))
+    except OSError:
+        pass
+
 
 class MyIndexView(IndexView):
     @expose('/')
diff --git a/caravel/data/__init__.py b/caravel/data/__init__.py
index 974ce2e898bae..0f9e2a308efa3 100644
--- a/caravel/data/__init__.py
+++ b/caravel/data/__init__.py
@@ -1058,7 +1058,7 @@ def load_multiformat_time_series_data():
         'string0': ['%Y-%m-%d %H:%M:%S.%f', None],
         'string3': ['%Y/%m/%d%H:%M:%S.%f', None],
     }
-    for col in obj.table_columns:
+    for col in obj.columns:
         dttm_and_expr = dttm_and_expr_dict[col.column_name]
         col.python_date_format = dttm_and_expr[0]
         col.database_expression = dttm_and_expr[1]
@@ -1069,7 +1069,7 @@ def load_multiformat_time_series_data():
     tbl = obj
 
     print("Creating some slices")
-    for i, col in enumerate(tbl.table_columns):
+    for i, col in enumerate(tbl.columns):
         slice_data = {
             "granularity_sqla": col.column_name,
             "datasource_id": "8",
diff --git a/caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py b/caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py
new file mode 100644
index 0000000000000..9d02ec5b4b105
--- /dev/null
+++ b/caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py
@@ -0,0 +1,28 @@
+"""Add params to the tables table.
+
+Revision ID: b46fa1b0b39e
+Revises: ef8843b41dac
+Create Date: 2016-10-05 11:30:31.748238
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = 'b46fa1b0b39e'
+down_revision = 'ef8843b41dac'
+
+from alembic import op
+import logging
+import sqlalchemy as sa
+
+
+def upgrade():
+    op.add_column('tables',
+                  sa.Column('params', sa.Text(), nullable=True))
+
+
+def downgrade():
+    try:
+        op.drop_column('tables', 'params')
+    except Exception as e:
+        logging.warning(str(e))
+
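For context on what ends up in that new column: the import/export code below keeps its bookkeeping in params as JSON. A minimal sketch of the shape (the key names are the ones used throughout this patch; the values are illustrative, taken from the unit tests at the bottom):

    import json

    # Illustrative contents of SqlaTable.params after an import; see the
    # unit tests at the end of this patch for real examples.
    params = json.dumps({
        'remote_id': 10002,       # id of the table on the exporting instance
        'import_time': 1990,      # epoch timestamp passed to import_obj()
        'database_name': 'main',  # used to re-attach the table to a database
    })
    assert json.loads(params)['remote_id'] == 10002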
diff --git a/caravel/models.py b/caravel/models.py
index f160e038d9605..706fc49a38410 100644
--- a/caravel/models.py
+++ b/caravel/models.py
@@ -7,6 +7,7 @@
 import functools
 import json
 import logging
+import pickle
 import re
 import textwrap
 from collections import namedtuple
@@ -18,6 +19,8 @@
 import requests
 import sqlalchemy as sqla
 from sqlalchemy.engine.url import make_url
+from sqlalchemy.orm import subqueryload
+
 import sqlparse
 from dateutil.parser import parse
@@ -41,6 +44,7 @@
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm.session import make_transient
 from sqlalchemy.sql import table, literal_column, text, column
 from sqlalchemy.sql.expression import ColumnClause, TextAsFrom
 from sqlalchemy_utils import EncryptedType
@@ -70,6 +74,31 @@ def __init__(self, name, field_names, function):
         self.name = name
 
 
+class ImportMixin(object):
+    def override(self, obj):
+        """Overrides the plain fields of the model with obj's values."""
+        for field in obj.__class__.export_fields:
+            setattr(self, field, getattr(obj, field))
+
+    def copy(self):
+        """Creates a copy of the model without relationships."""
+        new_obj = self.__class__()
+        new_obj.override(self)
+        return new_obj
+
+    def alter_params(self, **kwargs):
+        d = self.params_dict
+        d.update(kwargs)
+        self.params = json.dumps(d)
+
+    @property
+    def params_dict(self):
+        if self.params:
+            return json.loads(self.params)
+        else:
+            return {}
+
+
 class AuditMixinNullable(AuditMixin):
 
     """Altering the AuditMixin to use nullable fields
@@ -149,7 +178,7 @@ class CssTemplate(Model, AuditMixinNullable):
     )
 
 
-class Slice(Model, AuditMixinNullable):
+class Slice(Model, AuditMixinNullable, ImportMixin):
 
     """A slice is essentially a report or a view on data"""
 
@@ -166,6 +195,9 @@ class Slice(Model, AuditMixinNullable):
     perm = Column(String(2000))
     owners = relationship("User", secondary=slice_user)
 
+    export_fields = ('slice_name', 'datasource_type', 'datasource_name',
+                     'viz_type', 'params', 'cache_timeout')
+
     def __repr__(self):
         return self.slice_name
 
@@ -283,6 +315,42 @@ def get_viz(self, url_params_multidict=None):
             slice_=self
         )
 
+    @classmethod
+    def import_obj(cls, slc_to_import, import_time=None):
+        """Inserts or overrides slc in the database.
+
+        remote_id and import_time fields in params_dict are set to track the
+        slice origin and ensure correct overrides for multiple imports.
+        The datasource_name, schema and database_name fields in params_dict
+        are used to find the datasource and wire the slice to it.
+        """
+        session = db.session
+        make_transient(slc_to_import)
+        slc_to_import.dashboards = []
+        slc_to_import.alter_params(
+            remote_id=slc_to_import.id, import_time=import_time)
+
+        # find if the slice was already imported
+        slc_to_override = None
+        for slc in session.query(Slice).all():
+            if ('remote_id' in slc.params_dict and
+                    slc.params_dict['remote_id'] == slc_to_import.id):
+                slc_to_override = slc
+
+        slc_to_import.id = None
+        params = slc_to_import.params_dict
+        slc_to_import.datasource_id = SourceRegistry.get_datasource_by_name(
+            session, slc_to_import.datasource_type, params['datasource_name'],
+            params['schema'], params['database_name']).id
+        if slc_to_override:
+            slc_to_override.override(slc_to_import)
+            session.flush()
+            return slc_to_override.id
+        else:
+            session.add(slc_to_import)
+            logging.info('Final slice: {}'.format(slc_to_import.to_json()))
+            session.flush()
+            return slc_to_import.id
+
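As a usage sketch (mirroring the unit tests at the end of this patch; the ids are illustrative and the table/database names come from caravel's example data):

    import json
    from caravel import db, models

    # Build a transient slice whose params carry the lookup keys that
    # Slice.import_obj() needs, then insert it.
    slc = models.Slice(
        slice_name='Import Me',
        datasource_type='table',
        viz_type='bubble',
        params=json.dumps({
            'remote_id': 10001,
            'datasource_name': 'wb_health_population',
            'database_name': 'main',
            'schema': '',
        }),
        id=10001,
    )
    new_id = models.Slice.import_obj(slc, import_time=1989)
    db.session.commit()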
 
 def set_perm(mapper, connection, target):  # noqa
     src_class = target.cls_model
@@ -309,7 +377,7 @@ def set_perm(mapper, connection, target):  # noqa
     )
 
 
-class Dashboard(Model, AuditMixinNullable):
+class Dashboard(Model, AuditMixinNullable, ImportMixin):
 
     """The dashboard object!"""
 
@@ -325,6 +393,9 @@ class Dashboard(Model, AuditMixinNullable):
         'Slice', secondary=dashboard_slices, backref='dashboards')
     owners = relationship("User", secondary=dashboard_user)
 
+    export_fields = ('dashboard_title', 'position_json', 'json_metadata',
+                     'description', 'css', 'slug', 'slices')
+
     def __repr__(self):
         return self.dashboard_title
 
@@ -340,13 +411,6 @@ def url(self):
     def datasources(self):
         return {slc.datasource for slc in self.slices}
 
-    @property
-    def metadata_dejson(self):
-        if self.json_metadata:
-            return json.loads(self.json_metadata)
-        else:
-            return {}
-
     @property
     def sqla_metadata(self):
         metadata = MetaData(bind=self.get_sqla_engine())
@@ -361,7 +425,7 @@ def dashboard_link(self):
     def json_data(self):
         d = {
             'id': self.id,
-            'metadata': self.metadata_dejson,
+            'metadata': self.params_dict,
             'dashboard_title': self.dashboard_title,
             'slug': self.slug,
             'slices': [slc.data for slc in self.slices],
@@ -369,6 +433,107 @@ def json_data(self):
         }
         return json.dumps(d)
 
+    @property
+    def params(self):
+        return self.json_metadata
+
+    @params.setter
+    def params(self, value):
+        self.json_metadata = value
+
+    @classmethod
+    def import_obj(cls, dashboard_to_import, import_time=None):
+        """Imports the dashboard from the object to the database.
+
+        Once the dashboard is imported, the json_metadata field is extended
+        to store remote_id and import_time. These are used to decide whether
+        the dashboard has to be overridden or copied over as new. Slices
+        that belong to this dashboard will be wired to existing tables.
+        This function can be used to import/export dashboards between
+        multiple caravel instances. Audit metadata isn't copied over.
+        """
+        logging.info('Started import of the dashboard: {}'
+                     .format(dashboard_to_import.to_json()))
+        session = db.session
+        logging.info('Dashboard has {} slices'
+                     .format(len(dashboard_to_import.slices)))
+        # copy the slices as Slice.import_obj will mutate the slice
+        # and will remove the existing dashboard - slice association
+        slices = copy(dashboard_to_import.slices)
+        slice_ids = set()
+        for slc in slices:
+            logging.info('Importing slice {} from the dashboard: {}'.format(
+                slc.to_json(), dashboard_to_import.dashboard_title))
+            slice_ids.add(Slice.import_obj(slc, import_time=import_time))
+
+        # override the dashboard
+        existing_dashboard = None
+        for dash in session.query(Dashboard).all():
+            if ('remote_id' in dash.params_dict and
+                    dash.params_dict['remote_id'] ==
+                    dashboard_to_import.id):
+                existing_dashboard = dash
+
+        dashboard_to_import.id = None
+        dashboard_to_import.alter_params(import_time=import_time)
+        new_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
+
+        if existing_dashboard:
+            existing_dashboard.override(dashboard_to_import)
+            existing_dashboard.slices = new_slices
+            session.flush()
+            return existing_dashboard.id
+        else:
+            # session.add(dashboard_to_import) causes sqlalchemy failures
+            # related to the attached users / slices. Creating a new object
+            # allows us to avoid conflicts in the sqlalchemy state.
+            copied_dash = dashboard_to_import.copy()
+            copied_dash.slices = new_slices
+            session.add(copied_dash)
+            session.flush()
+            return copied_dash.id
+
+    @classmethod
+    def export_dashboards(cls, dashboard_ids):
+        copied_dashboards = []
+        datasource_ids = set()
+        for dashboard_id in dashboard_ids:
+            # make sure that dashboard_id is an integer
+            dashboard_id = int(dashboard_id)
+            copied_dashboard = (
+                db.session.query(Dashboard)
+                .options(subqueryload(Dashboard.slices))
+                .filter_by(id=dashboard_id).first()
+            )
+            make_transient(copied_dashboard)
+            for slc in copied_dashboard.slices:
+                datasource_ids.add((slc.datasource_id, slc.datasource_type))
+                # add extra params for the import
+                slc.alter_params(
+                    remote_id=slc.id,
+                    datasource_name=slc.datasource.name,
+                    schema=slc.datasource.name,
+                    database_name=slc.datasource.database.database_name,
+                )
+            copied_dashboard.alter_params(remote_id=dashboard_id)
+            copied_dashboards.append(copied_dashboard)
+
+        eager_datasources = []
+        for datasource_id, datasource_type in datasource_ids:
+            eager_datasource = SourceRegistry.get_eager_datasource(
+                db.session, datasource_type, datasource_id)
+            eager_datasource.alter_params(
+                remote_id=eager_datasource.id,
+                database_name=eager_datasource.database.database_name,
+            )
+            make_transient(eager_datasource)
+            eager_datasources.append(eager_datasource)
+
+        return pickle.dumps({
+            'dashboards': copied_dashboards,
+            'datasources': eager_datasources,
+        })
+
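The export payload is simply a pickle of detached ORM objects. Roughly, a consumer unpacks it like this (exported_bytes is a placeholder for the response body of the export endpoint added in views.py below):

    import pickle

    payload = pickle.loads(exported_bytes)
    dashboards = payload['dashboards']    # transient Dashboard objects
    datasources = payload['datasources']  # eager-loaded SqlaTable objects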
 
 class Queryable(object):
 
@@ -433,6 +598,10 @@ class Database(Model, AuditMixinNullable):
 
     def __repr__(self):
         return self.database_name
 
+    @property
+    def name(self):
+        return self.database_name
+
     @property
     def backend(self):
         url = make_url(self.sqlalchemy_uri_decrypted)
@@ -665,7 +834,7 @@ def perm(self):
             "[{obj.database_name}].(id:{obj.id})").format(obj=self)
 
 
-class SqlaTable(Model, Queryable, AuditMixinNullable):
+class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
 
     """An ORM object for SqlAlchemy table references"""
 
@@ -689,9 +858,13 @@ class SqlaTable(Model, Queryable, AuditMixinNullable):
     cache_timeout = Column(Integer)
     schema = Column(String(255))
     sql = Column(Text)
-    table_columns = relationship("TableColumn", back_populates="table")
+    params = Column(Text)
 
     baselink = "tablemodelview"
+    export_fields = (
+        'table_name', 'main_dttm_col', 'description', 'default_endpoint',
+        'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema',
+        'sql', 'params')
 
     __table_args__ = (
         sqla.UniqueConstraint(
@@ -773,7 +946,7 @@ def time_column_grains(self):
         }
 
     def get_col(self, col_name):
-        columns = self.table_columns
+        columns = self.columns
        for col in columns:
             if col_name == col.column_name:
                 return col
@@ -1062,8 +1235,67 @@ def fetch_metadata(self):
         if not self.main_dttm_col:
             self.main_dttm_col = any_date_col
 
+    @classmethod
+    def import_obj(cls, datasource_to_import, import_time=None):
+        """Imports the datasource from the object to the database.
+
+        The datasource and its metrics and columns will be overridden if
+        they already exist. This function can be used to import/export
+        dashboards between multiple caravel instances. Audit metadata
+        isn't copied over.
+        """
+        session = db.session
+        make_transient(datasource_to_import)
+        logging.info('Started import of the datasource: {}'
+                     .format(datasource_to_import.to_json()))
+
+        datasource_to_import.id = None
+        database_name = datasource_to_import.params_dict['database_name']
+        datasource_to_import.database_id = session.query(Database).filter_by(
+            database_name=database_name).one().id
+        datasource_to_import.alter_params(import_time=import_time)
+
+        # override the datasource
+        datasource = (
+            session.query(SqlaTable).join(Database)
+            .filter(
+                SqlaTable.table_name == datasource_to_import.table_name,
+                SqlaTable.schema == datasource_to_import.schema,
+                Database.id == datasource_to_import.database_id,
+            )
+            .first()
+        )
 
+        if datasource:
+            datasource.override(datasource_to_import)
+            session.flush()
+        else:
+            datasource = datasource_to_import.copy()
+            session.add(datasource)
+            session.flush()
 
-class SqlMetric(Model, AuditMixinNullable):
+        for m in datasource_to_import.metrics:
+            new_m = m.copy()
+            new_m.table_id = datasource.id
+            logging.info('Importing metric {} from the datasource: {}'.format(
+                new_m.to_json(), datasource_to_import.full_name))
+            imported_m = SqlMetric.import_obj(new_m)
+            if imported_m not in datasource.metrics:
+                datasource.metrics.append(imported_m)
+
+        for c in datasource_to_import.columns:
+            new_c = c.copy()
+            new_c.table_id = datasource.id
+            logging.info('Importing column {} from the datasource: {}'.format(
+                new_c.to_json(), datasource_to_import.full_name))
+            imported_c = TableColumn.import_obj(new_c)
+            if imported_c not in datasource.columns:
+                datasource.columns.append(imported_c)
+        db.session.flush()
+
+        return datasource.id
+
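A sketch of importing a table programmatically, again mirroring the tests at the bottom of this patch (the table, column and metric names are illustrative):

    import json
    from caravel import models

    # A transient table whose params name the target database.
    table = models.SqlaTable(
        table_name='pure_table',
        schema='',
        params=json.dumps({'remote_id': 10001, 'database_name': 'main'}),
    )
    table.columns.append(models.TableColumn(column_name='col1'))
    table.metrics.append(models.SqlMetric(metric_name='metric1'))
    imported_id = models.SqlaTable.import_obj(table, import_time=1989)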
relationship("TableColumn", back_populates="table") + params = Column(Text) baselink = "tablemodelview" + export_fields = ( + 'table_name', 'main_dttm_col', 'description', 'default_endpoint', + 'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema', + 'sql', 'params') __table_args__ = ( sqla.UniqueConstraint( @@ -773,7 +946,7 @@ def time_column_grains(self): } def get_col(self, col_name): - columns = self.table_columns + columns = self.columns for col in columns: if col_name == col.column_name: return col @@ -1062,8 +1235,67 @@ def fetch_metadata(self): if not self.main_dttm_col: self.main_dttm_col = any_date_col + @classmethod + def import_obj(cls, datasource_to_import, import_time=None): + """Imports the datasource from the object to the database. + + Metrics and columns and datasource will be overrided if exists. + This function can be used to import/export dashboards between multiple + caravel instances. Audit metadata isn't copies over. + """ + session = db.session + make_transient(datasource_to_import) + logging.info('Started import of the datasource: {}' + .format(datasource_to_import.to_json())) + + datasource_to_import.id = None + database_name = datasource_to_import.params_dict['database_name'] + datasource_to_import.database_id = session.query(Database).filter_by( + database_name=database_name).one().id + datasource_to_import.alter_params(import_time=import_time) + + # override the datasource + datasource = ( + session.query(SqlaTable).join(Database) + .filter( + SqlaTable.table_name == datasource_to_import.table_name, + SqlaTable.schema == datasource_to_import.schema, + Database.id == datasource_to_import.database_id, + ) + .first() + ) + + if datasource: + datasource.override(datasource_to_import) + session.flush() + else: + datasource = datasource_to_import.copy() + session.add(datasource) + session.flush() -class SqlMetric(Model, AuditMixinNullable): + for m in datasource_to_import.metrics: + new_m = m.copy() + new_m.table_id = datasource.id + logging.info('Importing metric {} from the datasource: {}'.format( + new_m.to_json(), datasource_to_import.full_name)) + imported_m = SqlMetric.import_obj(new_m) + if imported_m not in datasource.metrics: + datasource.metrics.append(imported_m) + + for c in datasource_to_import.columns: + new_c = c.copy() + new_c.table_id = datasource.id + logging.info('Importing column {} from the datasource: {}'.format( + new_c.to_json(), datasource_to_import.full_name)) + imported_c = TableColumn.import_obj(new_c) + if imported_c not in datasource.columns: + datasource.columns.append(imported_c) + db.session.flush() + + return datasource.id + + +class SqlMetric(Model, AuditMixinNullable, ImportMixin): """ORM object for metrics, each table can have multiple metrics""" @@ -1082,6 +1314,10 @@ class SqlMetric(Model, AuditMixinNullable): is_restricted = Column(Boolean, default=False, nullable=True) d3format = Column(String(128)) + export_fields = ( + 'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression', + 'description', 'is_restricted', 'd3format') + @property def sqla_col(self): name = self.metric_name @@ -1094,8 +1330,28 @@ def perm(self): ).format(obj=self, parent_name=self.table.full_name) if self.table else None + @classmethod + def import_obj(cls, metric_to_import): + session = db.session + make_transient(metric_to_import) + metric_to_import.id = None + + # find if the column was already imported + existing_metric = session.query(SqlMetric).filter( + SqlMetric.table_id == metric_to_import.table_id, + 
 
-class TableColumn(Model, AuditMixinNullable):
+class TableColumn(Model, AuditMixinNullable, ImportMixin):
 
     """ORM object for table columns, each table can have multiple columns"""
 
@@ -1125,6 +1381,12 @@ class TableColumn(Model, AuditMixinNullable):
     num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG')
     date_types = ('DATE', 'TIME')
     str_types = ('VARCHAR', 'STRING', 'CHAR')
+    export_fields = (
+        'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
+        'type', 'groupby', 'count_distinct', 'sum', 'max', 'min',
+        'filterable', 'expression', 'description', 'python_date_format',
+        'database_expression'
+    )
 
     def __repr__(self):
         return self.column_name
@@ -1150,6 +1412,27 @@ def sqla_col(self):
             col = literal_column(self.expression).label(name)
         return col
 
+    @classmethod
+    def import_obj(cls, column_to_import):
+        session = db.session
+        make_transient(column_to_import)
+        column_to_import.id = None
+
+        # find if the column was already imported
+        existing_column = session.query(TableColumn).filter(
+            TableColumn.table_id == column_to_import.table_id,
+            TableColumn.column_name == column_to_import.column_name).first()
+        column_to_import.table = None
+        if existing_column:
+            existing_column.override(column_to_import)
+            session.flush()
+            return existing_column
+
+        session.add(column_to_import)
+        session.flush()
+        return column_to_import
+
     def dttm_sql_literal(self, dttm):
         """Convert datetime object to string
@@ -1234,6 +1517,10 @@ def refresh_datasources(self):
     def perm(self):
         return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)
 
+    @property
+    def name(self):
+        return self.cluster_name
+
 
 class DruidDatasource(Model, AuditMixinNullable, Queryable):
 
@@ -1262,6 +1549,10 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable):
     offset = Column(Integer, default=0)
     cache_timeout = Column(Integer)
 
+    @property
+    def database(self):
+        return self.cluster
+
     @property
     def metrics_combo(self):
         return sorted(
diff --git a/caravel/source_registry.py b/caravel/source_registry.py
index dc1a170d7c68e..6176c9c0c96b1 100644
--- a/caravel/source_registry.py
+++ b/caravel/source_registry.py
@@ -1,3 +1,4 @@
+from sqlalchemy.orm import subqueryload
 
 
 class SourceRegistry(object):
@@ -20,3 +21,30 @@ def get_datasource(cls, datasource_type, datasource_id, session):
             .filter_by(id=datasource_id)
             .one()
         )
+
+    @classmethod
+    def get_datasource_by_name(cls, session, datasource_type, datasource_name,
+                               schema, database_name):
+        datasource_class = SourceRegistry.sources[datasource_type]
+        datasources = session.query(datasource_class).all()
+        db_ds = [d for d in datasources if d.database.name == database_name and
+                 d.name == datasource_name and schema == schema]
+        return db_ds[0]
+
+    @classmethod
+    def get_eager_datasource(cls, session, datasource_type, datasource_id):
+        """Returns datasource with columns and metrics."""
+        datasource_class = SourceRegistry.sources[datasource_type]
+        if datasource_type == 'table':
+            return (
+                session.query(datasource_class)
+                .options(
+                    subqueryload(datasource_class.columns),
+                    subqueryload(datasource_class.metrics)
+                )
+                .filter_by(id=datasource_id)
+                .one()
+            )
+        # TODO: support druid datasources.
+        return session.query(datasource_class).filter_by(
+            id=datasource_id).first()
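For reference, this is the lookup that Slice.import_obj relies on. With the example data it would be called roughly like so (the table and database names come from the test fixtures):

    from caravel import db
    from caravel.source_registry import SourceRegistry

    # Resolve a table by (type, name, schema, database), the way the
    # slice import does.
    datasource = SourceRegistry.get_datasource_by_name(
        db.session, 'table', 'wb_health_population', '', 'main')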
diff --git a/caravel/templates/caravel/export_dashboards.html b/caravel/templates/caravel/export_dashboards.html
new file mode 100644
index 0000000000000..83a0c840e09ff
--- /dev/null
+++ b/caravel/templates/caravel/export_dashboards.html
@@ -0,0 +1,6 @@
+
diff --git a/caravel/templates/caravel/import_dashboards.html b/caravel/templates/caravel/import_dashboards.html
new file mode 100644
index 0000000000000..8365bbe0f7be2
--- /dev/null
+++ b/caravel/templates/caravel/import_dashboards.html
@@ -0,0 +1,23 @@
+{% extends "caravel/basic.html" %}
+
+{# TODO: move the libs required by flask into the common.js from welcome.js. #}
+{% block head_js %}
+  {{ super() }}
+  {% with filename="welcome" %}
+    {% include "caravel/partials/_script_tag.html" %}
+  {% endwith %}
+{% endblock %}
+
+{% block title %}{{ _("Import") }}{% endblock %}
+{% block body %}
+  {% include "caravel/flash_wrapper.html" %}
+  <div class="container">
+    <h1>Import the dashboards</h1>
+    <form method="post" enctype="multipart/form-data">
+      <input type="file" name="file"/>
+      <input type="submit" value="Upload" class="btn btn-primary"/>
+    </form>
+  </div>
+{% endblock %}
diff --git a/caravel/views.py b/caravel/views.py
index 212cf0709b7ab..ea3fed6041917 100755
--- a/caravel/views.py
+++ b/caravel/views.py
@@ -5,6 +5,8 @@
 import json
 import logging
+import os
+import pickle
 import re
 import sys
 import time
@@ -15,7 +17,8 @@
 import sqlalchemy as sqla
 
 from flask import (
-    g, request, redirect, flash, Response, render_template, Markup)
+    g, request, make_response, redirect, flash, Response, render_template,
+    Markup, url_for)
 from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose
 from flask_appbuilder.actions import action
 from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -26,8 +29,9 @@
 from flask_appbuilder.models.sqla.filters import BaseFilter
 
 from sqlalchemy import create_engine
-from werkzeug.routing import BaseConverter
+from werkzeug import secure_filename
 from werkzeug.datastructures import ImmutableMultiDict
+from werkzeug.routing import BaseConverter
 from wtforms.validators import ValidationError
 
 import caravel
@@ -533,6 +537,16 @@ def pre_update(self, db):
         self.pre_add(db)
 
 
+appbuilder.add_link(
+    'Import Dashboards',
+    label=__("Import Dashboards"),
+    href='/caravel/import_dashboards',
+    icon="fa-cloud-upload",
+    category='Manage',
+    category_label=__("Manage"),
+    category_icon='fa-wrench',)
+
+
 appbuilder.add_view(
     DatabaseView,
     "Databases",
@@ -658,7 +672,6 @@ class AccessRequestsModelView(CaravelModelView, DeleteMixin):
     category_label=__("Security"),
     icon='fa-table',)
 
-
 appbuilder.add_separator("Sources")
 
 
@@ -867,13 +880,32 @@ def pre_update(self, obj):
     def pre_delete(self, obj):
         check_ownership(obj)
 
+    @action("mulexport", "Export", "Export dashboards?", "fa-database")
+    def mulexport(self, items):
+        ids = ''.join('&id={}'.format(d.id) for d in items)
+        return redirect(
+            '/dashboardmodelview/export_dashboards_form?{}'.format(ids[1:]))
+
+    @expose("/export_dashboards_form")
+    def download_dashboards(self):
+        if request.args.get('action') == 'go':
+            ids = request.args.getlist('id')
+            return Response(
+                models.Dashboard.export_dashboards(ids),
+                headers=generate_download_headers("pickle"),
+                mimetype="application/text")
+        return self.render_template(
+            'caravel/export_dashboards.html',
+            dashboards_url='/dashboardmodelview/list'
+        )
+
 
 appbuilder.add_view(
     DashboardModelView,
     "Dashboards",
     label=__("Dashboards"),
     icon="fa-dashboard",
-    category="",
+    category='',
     category_icon='',)
 
 
@@ -1053,9 +1085,8 @@ def approve(self):
         role_to_extend = request.args.get('role_to_extend')
 
         session = db.session
-        datasource_class = SourceRegistry.sources[datasource_type]
-        datasource = session.query(datasource_class).filter_by(
-            id=datasource_id).first()
+        datasource = SourceRegistry.get_datasource(
+            datasource_type, datasource_id, session)
 
         if not datasource:
             flash(DATASOURCE_MISSING_ERR, "alert")
@@ -1149,6 +1180,27 @@ def explore_json(self, datasource_type, datasource_id):
             status=200,
             mimetype="application/json")
 
+    @expose("/import_dashboards", methods=['GET', 'POST'])
+    @log_this
+    def import_dashboards(self):
+        """Overrides the dashboards using pickled instances from the file."""
+        f = request.files.get('file')
+        if request.method == 'POST' and f:
+            filename = secure_filename(f.filename)
+            filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+            f.save(filepath)
+            current_tt = int(time.time())
+            data = pickle.load(open(filepath, 'rb'))
+            for table in data['datasources']:
+                models.SqlaTable.import_obj(table, import_time=current_tt)
+            for dashboard in data['dashboards']:
+                models.Dashboard.import_obj(
+                    dashboard, import_time=current_tt)
+            os.remove(filepath)
+            db.session.commit()
+            return redirect('/dashboardmodelview/list/')
+        return self.render_template('caravel/import_dashboards.html')
+
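As a usage sketch, driving the new endpoint from a script (the URL, port and auth handling are assumptions; since the payload is a pickle, only import files from instances you trust):

    import requests

    # Upload a previously exported pickle to a (trusted) caravel instance.
    with open('dashboards.pickle', 'rb') as f:
        requests.post(
            'http://localhost:8088/caravel/import_dashboards',
            files={'file': f},
        )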
     @log_this
     @has_access
     @expose("/explore/<datasource_type>/<datasource_id>/")
@@ -1478,7 +1530,7 @@ def save_dash(self, dashboard_id):
         dash.slices = [o for o in dash.slices if o.id in slice_ids]
         positions = sorted(data['positions'], key=lambda x: int(x['slice_id']))
         dash.position_json = json.dumps(positions, indent=4, sort_keys=True)
-        md = dash.metadata_dejson
+        md = dash.params_dict
         if 'filter_immune_slices' not in md:
             md['filter_immune_slices'] = []
         if 'filter_immune_slice_fields' not in md:
diff --git a/tests/core_tests.py b/tests/core_tests.py
index bfd9700aaed09..bd09bfd8ca22d 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -452,6 +452,5 @@ def test_only_owners_can_save(self):
         db.session.commit()
         self.test_save_dash('alpha')
 
-
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py
new file mode 100644
index 0000000000000..454dbbb73b538
--- /dev/null
+++ b/tests/import_export_tests.py
@@ -0,0 +1,368 @@
+"""Unit tests for Caravel"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from sqlalchemy.orm.session import make_transient
+
+import json
+import pickle
+import unittest
+
+from caravel import db, models
+
+from .base_tests import CaravelTestCase
+
+
+class ImportExportTests(CaravelTestCase):
+    """Testing export import functionality for dashboards"""
+
+    def __init__(self, *args, **kwargs):
+        super(ImportExportTests, self).__init__(*args, **kwargs)
+
+    @classmethod
+    def delete_imports(cls):
+        # Imported data clean up
+        session = db.session
+        for slc in session.query(models.Slice):
+            if 'remote_id' in slc.params_dict:
+                session.delete(slc)
+        for dash in session.query(models.Dashboard):
+            if 'remote_id' in dash.params_dict:
+                session.delete(dash)
+        for table in session.query(models.SqlaTable):
+            if 'remote_id' in table.params_dict:
+                session.delete(table)
+        session.commit()
+
+    @classmethod
+    def setUpClass(cls):
+        cls.delete_imports()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.delete_imports()
+
+    def create_slice(self, name, ds_id=None, id=None, db_name='main',
+                     table_name='wb_health_population'):
+        params = {
+            'num_period_compare': '10',
+            'remote_id': id,
+            'datasource_name': table_name,
+            'database_name': db_name,
+            'schema': '',
+        }
+
+        if table_name and not ds_id:
+            table = self.get_table_by_name(table_name)
+            if table:
+                ds_id = table.id
+
+        return models.Slice(
+            slice_name=name,
+            datasource_type='table',
+            viz_type='bubble',
+            params=json.dumps(params),
+            datasource_id=ds_id,
+            id=id
+        )
+
+    def create_dashboard(self, title, id=0, slcs=[]):
+        json_metadata = {'remote_id': id}
+        return models.Dashboard(
+            id=id,
+            dashboard_title=title,
+            slices=slcs,
+            position_json='{"size_y": 2, "size_x": 2}',
+            slug='{}_imported'.format(title.lower()),
+            json_metadata=json.dumps(json_metadata)
+        )
+
+    def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]):
+        params = {'remote_id': id, 'database_name': 'main'}
+        table = models.SqlaTable(
+            id=id,
+            schema=schema,
+            table_name=name,
+            params=json.dumps(params)
+        )
+        for col_name in cols_names:
+            table.columns.append(
+                models.TableColumn(column_name=col_name))
+        for metric_name in metric_names:
+            table.metrics.append(models.SqlMetric(metric_name=metric_name))
+        return table
+
+    def get_slice(self, slc_id):
+        return db.session.query(models.Slice).filter_by(id=slc_id).first()
+
+    def get_dash(self, dash_id):
+        return db.session.query(models.Dashboard).filter_by(
+            id=dash_id).first()
+
+    def get_dash_by_slug(self, dash_slug):
+        return db.session.query(models.Dashboard).filter_by(
+            slug=dash_slug).first()
+
+    def get_table(self, table_id):
+        return db.session.query(models.SqlaTable).filter_by(
+            id=table_id).first()
+
+    def get_table_by_name(self, name):
+        return db.session.query(models.SqlaTable).filter_by(
+            table_name=name).first()
+
+    def assert_dash_equals(self, expected_dash, actual_dash):
+        self.assertEquals(expected_dash.slug, actual_dash.slug)
+        self.assertEquals(
+            expected_dash.dashboard_title, actual_dash.dashboard_title)
+        self.assertEquals(
+            expected_dash.position_json, actual_dash.position_json)
+        self.assertEquals(
+            len(expected_dash.slices), len(actual_dash.slices))
+        expected_slices = sorted(
+            expected_dash.slices, key=lambda s: s.slice_name)
+        actual_slices = sorted(
+            actual_dash.slices, key=lambda s: s.slice_name)
+        for e_slc, a_slc in zip(expected_slices, actual_slices):
+            self.assert_slice_equals(e_slc, a_slc)
+
+    def assert_table_equals(self, expected_ds, actual_ds):
+        self.assertEquals(expected_ds.table_name, actual_ds.table_name)
+        self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
+        self.assertEquals(expected_ds.schema, actual_ds.schema)
+        self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics))
+        self.assertEquals(len(expected_ds.columns), len(actual_ds.columns))
+        self.assertEquals(
+            set([c.column_name for c in expected_ds.columns]),
+            set([c.column_name for c in actual_ds.columns]))
+        self.assertEquals(
+            set([m.metric_name for m in expected_ds.metrics]),
+            set([m.metric_name for m in actual_ds.metrics]))
+
+    def assert_slice_equals(self, expected_slc, actual_slc):
+        self.assertEquals(actual_slc.datasource.perm, actual_slc.perm)
+        self.assertEquals(expected_slc.slice_name, actual_slc.slice_name)
+        self.assertEquals(
+            expected_slc.datasource_type, actual_slc.datasource_type)
+        self.assertEquals(expected_slc.viz_type, actual_slc.viz_type)
+        self.assertEquals(
+            json.loads(expected_slc.params), json.loads(actual_slc.params))
+
+    def test_export_1_dashboard(self):
+        birth_dash = self.get_dash_by_slug('births')
+        export_dash_url = (
+            '/dashboardmodelview/export_dashboards_form?id={}&action=go'
+            .format(birth_dash.id)
+        )
+        resp = self.client.get(export_dash_url)
+        exported_dashboards = pickle.loads(resp.data)['dashboards']
+        self.assert_dash_equals(birth_dash, exported_dashboards[0])
+        self.assertEquals(
+            birth_dash.id,
+            json.loads(exported_dashboards[0].json_metadata)['remote_id'])
+
+        exported_tables = pickle.loads(resp.data)['datasources']
+        self.assertEquals(1, len(exported_tables))
+        self.assert_table_equals(
+            self.get_table_by_name('birth_names'), exported_tables[0])
+
+    def test_export_2_dashboards(self):
+        birth_dash = self.get_dash_by_slug('births')
+        world_health_dash = self.get_dash_by_slug('world_health')
+        export_dash_url = (
+            '/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go'
+            .format(birth_dash.id, world_health_dash.id))
+        resp = self.client.get(export_dash_url)
+        exported_dashboards = sorted(pickle.loads(resp.data)['dashboards'],
+                                     key=lambda d: d.dashboard_title)
+        self.assertEquals(2, len(exported_dashboards))
+        self.assert_dash_equals(birth_dash, exported_dashboards[0])
+        self.assertEquals(
+            birth_dash.id,
+            json.loads(exported_dashboards[0].json_metadata)['remote_id']
+        )
+
+        self.assert_dash_equals(world_health_dash, exported_dashboards[1])
+        self.assertEquals(
+            world_health_dash.id,
+            json.loads(exported_dashboards[1].json_metadata)['remote_id']
+        )
+
+        exported_tables = sorted(
+            pickle.loads(resp.data)['datasources'], key=lambda t: t.table_name)
+        self.assertEquals(2, len(exported_tables))
+        self.assert_table_equals(
+            self.get_table_by_name('birth_names'), exported_tables[0])
+        self.assert_table_equals(
+            self.get_table_by_name('wb_health_population'), exported_tables[1])
+
+    def test_import_1_slice(self):
+        expected_slice = self.create_slice('Import Me', id=10001)
+        slc_id = models.Slice.import_obj(expected_slice, import_time=1989)
+        self.assert_slice_equals(expected_slice, self.get_slice(slc_id))
+
+        table_id = self.get_table_by_name('wb_health_population').id
+        self.assertEquals(table_id, self.get_slice(slc_id).datasource_id)
+
+    def test_import_2_slices_for_same_table(self):
+        table_id = self.get_table_by_name('wb_health_population').id
+        # table_id != 666, import func will have to find the table
+        slc_1 = self.create_slice('Import Me 1', ds_id=666, id=10002)
+        slc_id_1 = models.Slice.import_obj(slc_1)
+        slc_2 = self.create_slice('Import Me 2', ds_id=666, id=10003)
+        slc_id_2 = models.Slice.import_obj(slc_2)
+
+        imported_slc_1 = self.get_slice(slc_id_1)
+        imported_slc_2 = self.get_slice(slc_id_2)
+        self.assertEquals(table_id, imported_slc_1.datasource_id)
+        self.assert_slice_equals(slc_1, imported_slc_1)
+
+        self.assertEquals(table_id, imported_slc_2.datasource_id)
+        self.assert_slice_equals(slc_2, imported_slc_2)
+
+    def test_import_slices_for_non_existent_table(self):
+        with self.assertRaises(IndexError):
+            models.Slice.import_obj(self.create_slice(
+                'Import Me 3', id=10004, table_name='non_existent'))
+
+    def test_import_slices_override(self):
+        slc = self.create_slice('Import Me New', id=10005)
+        slc_1_id = models.Slice.import_obj(slc, import_time=1990)
+        slc.slice_name = 'Import Me New'
+        slc_2_id = models.Slice.import_obj(
+            self.create_slice('Import Me New', id=10005), import_time=1990)
+        self.assertEquals(slc_1_id, slc_2_id)
+        imported_slc = self.get_slice(slc_2_id)
+        self.assert_slice_equals(slc, imported_slc)
+
+    def test_import_empty_dashboard(self):
+        empty_dash = self.create_dashboard('empty_dashboard', id=10001)
+        imported_dash_id = models.Dashboard.import_obj(
+            empty_dash, import_time=1989)
+        imported_dash = self.get_dash(imported_dash_id)
+        self.assert_dash_equals(empty_dash, imported_dash)
+
+    def test_import_dashboard_1_slice(self):
+        slc = self.create_slice('health_slc', id=10006)
+        dash_with_1_slice = self.create_dashboard(
+            'dash_with_1_slice', slcs=[slc], id=10002)
+        imported_dash_id = models.Dashboard.import_obj(
+            dash_with_1_slice, import_time=1990)
+        imported_dash = self.get_dash(imported_dash_id)
+
+        expected_dash = self.create_dashboard(
+            'dash_with_1_slice', slcs=[slc], id=10002)
+        make_transient(expected_dash)
+        self.assert_dash_equals(expected_dash, imported_dash)
+        self.assertEquals({"remote_id": 10002, "import_time": 1990},
+                          json.loads(imported_dash.json_metadata))
+
+    def test_import_dashboard_2_slices(self):
+        e_slc = self.create_slice('e_slc', id=10007, table_name='energy_usage')
+        b_slc = self.create_slice('b_slc', id=10008, table_name='birth_names')
+        dash_with_2_slices = self.create_dashboard(
+            'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003)
+        imported_dash_id = models.Dashboard.import_obj(
+            dash_with_2_slices, import_time=1991)
+        imported_dash = self.get_dash(imported_dash_id)
+
+        expected_dash = self.create_dashboard(
+            'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003)
+        make_transient(expected_dash)
+        self.assert_dash_equals(imported_dash, expected_dash)
+        self.assertEquals({"remote_id": 10003, "import_time": 1991},
+                          json.loads(imported_dash.json_metadata))
+
+    def test_import_override_dashboard_2_slices(self):
+        e_slc = self.create_slice('e_slc', id=10009, table_name='energy_usage')
+        b_slc = self.create_slice('b_slc', id=10010, table_name='birth_names')
+        dash_to_import = self.create_dashboard(
+            'override_dashboard', slcs=[e_slc, b_slc], id=10004)
+        imported_dash_id_1 = models.Dashboard.import_obj(
+            dash_to_import, import_time=1992)
+
+        # create new instances of the slices
+        e_slc = self.create_slice(
+            'e_slc', id=10009, table_name='energy_usage')
+        b_slc = self.create_slice(
+            'b_slc', id=10010, table_name='birth_names')
+        c_slc = self.create_slice('c_slc', id=10011, table_name='birth_names')
+        dash_to_import_override = self.create_dashboard(
+            'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004)
+        imported_dash_id_2 = models.Dashboard.import_obj(
+            dash_to_import_override, import_time=1992)
+
+        # override doesn't change the id
+        self.assertEquals(imported_dash_id_1, imported_dash_id_2)
+        expected_dash = self.create_dashboard(
+            'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004)
+        make_transient(expected_dash)
+        imported_dash = self.get_dash(imported_dash_id_2)
+        self.assert_dash_equals(expected_dash, imported_dash)
+        self.assertEquals({"remote_id": 10004, "import_time": 1992},
+                          json.loads(imported_dash.json_metadata))
+
+    def test_import_table_no_metadata(self):
+        table = self.create_table('pure_table', id=10001)
+        imported_t_id = models.SqlaTable.import_obj(table, import_time=1989)
+        imported_table = self.get_table(imported_t_id)
+        self.assert_table_equals(table, imported_table)
+
+    def test_import_table_1_col_1_met(self):
+        table = self.create_table(
+            'table_1_col_1_met', id=10002,
+            cols_names=["col1"], metric_names=["metric1"])
+        imported_t_id = models.SqlaTable.import_obj(table, import_time=1990)
+        imported_table = self.get_table(imported_t_id)
+        self.assert_table_equals(table, imported_table)
+        self.assertEquals(
+            {'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'},
+            json.loads(imported_table.params))
+
+    def test_import_table_2_col_2_met(self):
+        table = self.create_table(
+            'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'],
+            metric_names=['m1', 'm2'])
+        imported_t_id = models.SqlaTable.import_obj(table, import_time=1991)
+
+        imported_table = self.get_table(imported_t_id)
+        self.assert_table_equals(table, imported_table)
+
+    def test_import_table_override(self):
+        table = self.create_table(
+            'table_override', id=10003, cols_names=['col1'],
+            metric_names=['m1'])
+        imported_t_id = models.SqlaTable.import_obj(table, import_time=1991)
+
+        table_over = self.create_table(
+            'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_table_over_id = models.SqlaTable.import_obj(
+            table_over, import_time=1992)
+
+        imported_table_over = self.get_table(imported_table_over_id)
+        self.assertEquals(imported_t_id, imported_table_over.id)
+        expected_table = self.create_table(
+            'table_override', id=10003, metric_names=['new_metric1', 'm1'],
+            cols_names=['col1', 'new_col1', 'col2', 'col3'])
+        self.assert_table_equals(expected_table, imported_table_over)
+
+    def test_import_table_override_identical(self):
+        table = self.create_table(
+            'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_t_id = models.SqlaTable.import_obj(table, import_time=1993)
+
+        copy_table = self.create_table(
+            'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'],
+            metric_names=['new_metric1'])
+        imported_t_id_copy = models.SqlaTable.import_obj(
+            copy_table, import_time=1994)
+
+        self.assertEquals(imported_t_id, imported_t_id_copy)
+        self.assert_table_equals(copy_table, self.get_table(imported_t_id))
+
+if __name__ == '__main__':
+    unittest.main()
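Completing the round trip from the upload sketch above: pulling an export over HTTP the same way these tests do (the URL, port and dashboard ids are illustrative):

    import requests

    # Fetch a pickled export of dashboards 1 and 2 from a running instance.
    resp = requests.get(
        'http://localhost:8088/dashboardmodelview/export_dashboards_form'
        '?id=1&id=2&action=go')
    with open('dashboards.pickle', 'wb') as f:
        f.write(resp.content)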