From 09f2d92bfd570700daf041eafe2f026bd76f9fc2 Mon Sep 17 00:00:00 2001
From: Kristen Armes <6732445+kristenarmes@users.noreply.github.com>
Date: Tue, 16 Aug 2022 11:32:12 -0700
Subject: [PATCH] perf: New neo4j csv publisher to improve performance using batched params (#1957)

* New publisher using managed transactions, unwind statements, configurable two or one way relations

Signed-off-by: Kristen Armes

* Add logic for preserving adhoc ui data, and move write transactions logic to reusable function

Signed-off-by: Kristen Armes

* Fixing flake8 error (not sure how it got through, not from my change)

Signed-off-by: Kristen Armes

* Pulling in a few of the latest changes from the original publisher and lint

Signed-off-by: Kristen Armes

* Fix tests

Signed-off-by: Kristen Armes

* Addressing PR feedback and change index creation to use a managed transaction

Signed-off-by: Kristen Armes

* Refactor props body function and separate out constants

Signed-off-by: Kristen Armes

* Addressing PR feedback plus refactoring

Signed-off-by: Kristen Armes

* Addressing PR feedback, more refactoring

Signed-off-by: Kristen Armes

* Minor fixes and bump version

Signed-off-by: Kristen Armes

Signed-off-by: Kristen Armes
---
 .../hive_table_last_updated_extractor.py      |   4 +-
 .../publisher/neo4j_csv_publisher.py          |  35 +-
 .../publisher/neo4j_csv_unwind_publisher.py   | 346 ++++++++++++++++++
 .../publisher/publisher_config_constants.py   |  62 ++++
 .../databuilder/utils/publisher_utils.py      | 119 ++++++
 databuilder/setup.py                          |   2 +-
 .../test_neo4j_csv_unwind_publisher.py        |  74 ++++
 7 files changed, 623 insertions(+), 19 deletions(-)
 create mode 100644 databuilder/databuilder/publisher/neo4j_csv_unwind_publisher.py
 create mode 100644 databuilder/databuilder/publisher/publisher_config_constants.py
 create mode 100644 databuilder/databuilder/utils/publisher_utils.py
 create mode 100644 databuilder/tests/unit/publisher/test_neo4j_csv_unwind_publisher.py

diff --git a/databuilder/databuilder/extractor/hive_table_last_updated_extractor.py b/databuilder/databuilder/extractor/hive_table_last_updated_extractor.py
index 68d4e9b3a9..4c718ded13 100644
--- a/databuilder/databuilder/extractor/hive_table_last_updated_extractor.py
+++ b/databuilder/databuilder/extractor/hive_table_last_updated_extractor.py
@@ -113,8 +113,8 @@ class HiveTableLastUpdatedExtractor(Extractor):
     AND NOT EXISTS (SELECT * FROM PARTITION_KEYS WHERE PARTITION_KEYS.TBL_ID = TBLS.TBL_ID)
     """
-    DEFAULT_POSTGRES_ADDTIONAL_WHERE_CLAUSE = """ NOT EXISTS (SELECT * FROM "PARTITIONS" p WHERE p."TBL_ID" = t."TBL_ID")
-    AND NOT EXISTS (SELECT * FROM "PARTITION_KEYS" pk WHERE pk."TBL_ID" = t."TBL_ID")
+    DEFAULT_POSTGRES_ADDTIONAL_WHERE_CLAUSE = """ NOT EXISTS (SELECT * FROM "PARTITIONS" p
+    WHERE p."TBL_ID" = t."TBL_ID") AND NOT EXISTS (SELECT * FROM "PARTITION_KEYS" pk WHERE pk."TBL_ID" = t."TBL_ID")
     """
 
     DATABASE = 'hive'
diff --git a/databuilder/databuilder/publisher/neo4j_csv_publisher.py b/databuilder/databuilder/publisher/neo4j_csv_publisher.py
index fb9f8c1f73..e0a99a3f78 100644
--- a/databuilder/databuilder/publisher/neo4j_csv_publisher.py
+++ b/databuilder/databuilder/publisher/neo4j_csv_publisher.py
@@ -24,6 +24,9 @@
 from databuilder.publisher.base_publisher import Publisher
 from databuilder.publisher.neo4j_preprocessor import NoopRelationPreprocessor
+from databuilder.publisher.publisher_config_constants import (
+    Neo4jCsvPublisherConfigs, PublishBehaviorConfigs, PublisherConfigs,
+)
 
 # Setting field_size_limit to solve the error below
 # _csv.Error: field larger than field limit (131072)
@@ -32,53 +35,53 @@
 
 # Config keys
 # A directory that contains CSV files for nodes
-NODE_FILES_DIR = 'node_files_directory'
+NODE_FILES_DIR = PublisherConfigs.NODE_FILES_DIR
 # A directory that contains CSV files for relationships
-RELATION_FILES_DIR = 'relation_files_directory'
+RELATION_FILES_DIR = PublisherConfigs.RELATION_FILES_DIR
 # A end point for Neo4j e.g: bolt://localhost:9999
-NEO4J_END_POINT_KEY = 'neo4j_endpoint'
+NEO4J_END_POINT_KEY = Neo4jCsvPublisherConfigs.NEO4J_END_POINT_KEY
 # A transaction size that determines how often it commits.
-NEO4J_TRANSACTION_SIZE = 'neo4j_transaction_size'
+NEO4J_TRANSACTION_SIZE = Neo4jCsvPublisherConfigs.NEO4J_TRANSACTION_SIZE
 # A progress report frequency that determines how often it report the progress.
 NEO4J_PROGRESS_REPORT_FREQUENCY = 'neo4j_progress_report_frequency'
 # A boolean flag to make it fail if relationship is not created
 NEO4J_RELATIONSHIP_CREATION_CONFIRM = 'neo4j_relationship_creation_confirm'
 
-NEO4J_MAX_CONN_LIFE_TIME_SEC = 'neo4j_max_conn_life_time_sec'
+NEO4J_MAX_CONN_LIFE_TIME_SEC = Neo4jCsvPublisherConfigs.NEO4J_MAX_CONN_LIFE_TIME_SEC
 
 # list of nodes that are create only, and not updated if match exists
-NEO4J_CREATE_ONLY_NODES = 'neo4j_create_only_nodes'
+NEO4J_CREATE_ONLY_NODES = Neo4jCsvPublisherConfigs.NEO4J_CREATE_ONLY_NODES
 
 # list of node labels that could attempt to be accessed simultaneously
 NEO4J_DEADLOCK_NODE_LABELS = 'neo4j_deadlock_node_labels'
 
-NEO4J_USER = 'neo4j_user'
-NEO4J_PASSWORD = 'neo4j_password'
+NEO4J_USER = Neo4jCsvPublisherConfigs.NEO4J_USER
+NEO4J_PASSWORD = Neo4jCsvPublisherConfigs.NEO4J_PASSWORD
 
 # in Neo4j (v4.0+), we can create and use more than one active database at the same time
-NEO4J_DATABASE_NAME = 'neo4j_database'
+NEO4J_DATABASE_NAME = Neo4jCsvPublisherConfigs.NEO4J_DATABASE_NAME
 
 # NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting
-NEO4J_ENCRYPTED = 'neo4j_encrypted'
+NEO4J_ENCRYPTED = Neo4jCsvPublisherConfigs.NEO4J_ENCRYPTED
 # NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS
 # cert against system CAs
-NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl'
+NEO4J_VALIDATE_SSL = Neo4jCsvPublisherConfigs.NEO4J_VALIDATE_SSL
 
 # This will be used to provide unique tag to the node and relationship
-JOB_PUBLISH_TAG = 'job_publish_tag'
+JOB_PUBLISH_TAG = PublisherConfigs.JOB_PUBLISH_TAG
 
 # any additional fields that should be added to nodes and rels through config
-ADDITIONAL_FIELDS = 'additional_fields'
+ADDITIONAL_FIELDS = PublisherConfigs.ADDITIONAL_PUBLISHER_METADATA_FIELDS
 
 # Neo4j property name for published tag
-PUBLISHED_TAG_PROPERTY_NAME = 'published_tag'
+PUBLISHED_TAG_PROPERTY_NAME = PublisherConfigs.PUBLISHED_TAG_PROPERTY_NAME
 # Neo4j property name for last updated timestamp
-LAST_UPDATED_EPOCH_MS = 'publisher_last_updated_epoch_ms'
+LAST_UPDATED_EPOCH_MS = PublisherConfigs.LAST_UPDATED_EPOCH_MS
 
 # A boolean flag to indicate if publisher_metadata (e.g. published_tag,
 # publisher_last_updated_epoch_ms)
 # will be included as properties of the Neo4j nodes
-ADD_PUBLISHER_METADATA = 'add_publisher_metadata'
+ADD_PUBLISHER_METADATA = PublishBehaviorConfigs.ADD_PUBLISHER_METADATA
 
 RELATION_PREPROCESSOR = 'relation_preprocessor'
diff --git a/databuilder/databuilder/publisher/neo4j_csv_unwind_publisher.py b/databuilder/databuilder/publisher/neo4j_csv_unwind_publisher.py
new file mode 100644
index 0000000000..f83c2a3d15
--- /dev/null
+++ b/databuilder/databuilder/publisher/neo4j_csv_unwind_publisher.py
@@ -0,0 +1,346 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import csv
+import ctypes
+import logging
+import time
+from io import open
+from typing import (
+    Dict, List, Set,
+)
+
+import neo4j
+import pandas
+from jinja2 import Template
+from neo4j import GraphDatabase, Neo4jDriver
+from neo4j.api import (
+    SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, parse_neo4j_uri,
+)
+from pyhocon import ConfigFactory, ConfigTree
+
+from databuilder.models.graph_serializable import (
+    NODE_KEY, NODE_LABEL, RELATION_END_KEY, RELATION_END_LABEL, RELATION_REVERSE_TYPE, RELATION_START_KEY,
+    RELATION_START_LABEL, RELATION_TYPE,
+)
+from databuilder.publisher.base_publisher import Publisher
+from databuilder.publisher.publisher_config_constants import (
+    Neo4jCsvPublisherConfigs, PublishBehaviorConfigs, PublisherConfigs,
+)
+from databuilder.utils.publisher_utils import (
+    chunkify_list, create_neo4j_node_key_constraint, create_props_param, execute_neo4j_statement, get_props_body_keys,
+    list_files,
+)
+
+# Setting field_size_limit to solve the error below
+# _csv.Error: field larger than field limit (131072)
+# https://stackoverflow.com/a/54517228/5972935
+csv.field_size_limit(int(ctypes.c_ulong(-1).value // 2))
+
+# Required columns for Node
+NODE_REQUIRED_KEYS = {NODE_LABEL, NODE_KEY}
+# Required columns for Relationship
+RELATION_REQUIRED_KEYS = {RELATION_START_LABEL, RELATION_START_KEY,
+                          RELATION_END_LABEL, RELATION_END_KEY,
+                          RELATION_TYPE, RELATION_REVERSE_TYPE}
+
+DEFAULT_CONFIG = ConfigFactory.from_dict({Neo4jCsvPublisherConfigs.NEO4J_TRANSACTION_SIZE: 1000,
+                                          Neo4jCsvPublisherConfigs.NEO4J_MAX_CONN_LIFE_TIME_SEC: 50,
+                                          Neo4jCsvPublisherConfigs.NEO4J_DATABASE_NAME: neo4j.DEFAULT_DATABASE,
+                                          PublishBehaviorConfigs.ADD_PUBLISHER_METADATA: True,
+                                          PublishBehaviorConfigs.PUBLISH_REVERSE_RELATIONSHIPS: True,
+                                          PublishBehaviorConfigs.PRESERVE_ADHOC_UI_DATA: True})
+
+LOGGER = logging.getLogger(__name__)
+
+
+class Neo4jCsvUnwindPublisher(Publisher):
+    """
+    This publisher takes two folders for input and publishes to Neo4j.
+    One folder will contain CSV file(s) for Nodes, while the other folder will contain CSV
+    file(s) for Relationships.
+
+    The merge statements make use of the UNWIND clause to apply batched parameters to each
+    statement. This improves performance by reducing the number of individual transactions
+    sent to the database, and by allowing Neo4j to compile and cache the statement.
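+
+    For illustration, a node merge statement rendered by this publisher has roughly the
+    following shape (the ``Table`` label and ``name`` property are placeholders; real values
+    come from the input CSV):
+
+        UNWIND $batch AS row
+        MERGE (node:Table {key: row.KEY})
+        ON CREATE SET node.name = row.name
+        ON MATCH SET node.name = row.name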
+ """ + + def init(self, conf: ConfigTree) -> None: + conf = conf.with_fallback(DEFAULT_CONFIG) + + self._count: int = 0 + self._node_files = list_files(conf, PublisherConfigs.NODE_FILES_DIR) + self._node_files_iter = iter(self._node_files) + + self._relation_files = list_files(conf, PublisherConfigs.RELATION_FILES_DIR) + self._relation_files_iter = iter(self._relation_files) + + self._driver = self._driver_init(conf) + self._db_name = conf.get_string(Neo4jCsvPublisherConfigs.NEO4J_DATABASE_NAME) + self._transaction_size = conf.get_int(Neo4jCsvPublisherConfigs.NEO4J_TRANSACTION_SIZE) + + # config is list of node label. + # When set, this list specifies a list of nodes that shouldn't be updated, if exists + self._create_only_nodes = set(conf.get_list(Neo4jCsvPublisherConfigs.NEO4J_CREATE_ONLY_NODES, default=[])) + self._labels: Set[str] = set() + self._publish_tag: str = conf.get_string(PublisherConfigs.JOB_PUBLISH_TAG) + self._additional_publisher_metadata_fields: Dict =\ + dict(conf.get(PublisherConfigs.ADDITIONAL_PUBLISHER_METADATA_FIELDS, default={})) + self._add_publisher_metadata: bool = conf.get_bool(PublishBehaviorConfigs.ADD_PUBLISHER_METADATA) + self._publish_reverse_relationships: bool = conf.get_bool(PublishBehaviorConfigs.PUBLISH_REVERSE_RELATIONSHIPS) + self._preserve_adhoc_ui_data = conf.get_bool(PublishBehaviorConfigs.PRESERVE_ADHOC_UI_DATA) + if self._add_publisher_metadata and not self._publish_tag: + raise Exception(f'{PublisherConfigs.JOB_PUBLISH_TAG} should not be empty') + + LOGGER.info('Publishing Node csv files %s, and Relation CSV files %s', + self._node_files, + self._relation_files) + + def _driver_init(self, conf: ConfigTree) -> Neo4jDriver: + uri = conf.get_string(Neo4jCsvPublisherConfigs.NEO4J_END_POINT_KEY) + driver_args = { + 'uri': uri, + 'max_connection_lifetime': conf.get_int(Neo4jCsvPublisherConfigs.NEO4J_MAX_CONN_LIFE_TIME_SEC), + 'auth': (conf.get_string(Neo4jCsvPublisherConfigs.NEO4J_USER), + conf.get_string(Neo4jCsvPublisherConfigs.NEO4J_PASSWORD)), + } + + # if URI scheme not secure set `trust`` and `encrypted` to default values + # https://neo4j.com/docs/api/python-driver/current/api.html#uri + _, security_type, _ = parse_neo4j_uri(uri=uri) + if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]: + default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True} + driver_args.update(default_security_conf) + + # if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver + validate_ssl_conf = conf.get(Neo4jCsvPublisherConfigs.NEO4J_VALIDATE_SSL, None) + encrypted_conf = conf.get(Neo4jCsvPublisherConfigs.NEO4J_ENCRYPTED, None) + if validate_ssl_conf is not None: + driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \ + else neo4j.TRUST_ALL_CERTIFICATES + if encrypted_conf is not None: + driver_args['encrypted'] = encrypted_conf + + driver = GraphDatabase.driver(**driver_args) + + try: + driver.verify_connectivity() + except Exception as e: + driver.close() + raise e + + return driver + + def publish_impl(self) -> None: # noqa: C901 + """ + Publishes Nodes first and then Relations + """ + start = time.time() + + for node_file in self._node_files: + self.pre_publish_node_file(node_file) + + LOGGER.info('Publishing Node files: %s', self._node_files) + while True: + try: + node_file = next(self._node_files_iter) + self._publish_node_file(node_file) + except StopIteration: + break + + for rel_file in self._relation_files: + 
+            self.pre_publish_rel_file(rel_file)
+
+        LOGGER.info('Publishing Relationship files: %s', self._relation_files)
+        while True:
+            try:
+                relation_file = next(self._relation_files_iter)
+                self._publish_relation_file(relation_file)
+            except StopIteration:
+                break
+
+        LOGGER.info('Committed total %i statements', self._count)
+
+        # TODO: Add statsd support
+        LOGGER.info('Successfully published. Elapsed: %i seconds', time.time() - start)
+
+    def get_scope(self) -> str:
+        return 'publisher.neo4j'
+
+    # Can be overridden with custom action(s)
+    def pre_publish_node_file(self, node_file: str) -> None:
+        created_constraint_labels = create_neo4j_node_key_constraint(node_file, self._labels,
+                                                                     self._driver, self._db_name)
+        # set.union() returns a new set, so the result must be assigned back
+        self._labels = self._labels.union(created_constraint_labels)
+
+    # Can be overridden with custom action(s)
+    def pre_publish_rel_file(self, rel_file: str) -> None:
+        pass
+
+    def _publish_node_file(self, node_file: str) -> None:
+        with open(node_file, 'r', encoding='utf8') as node_csv:
+            csv_dataframe = pandas.read_csv(node_csv, na_filter=False)
+            all_node_records = csv_dataframe.to_dict(orient="records")
+
+            # Get the first node label since it will be the same for all records in the file
+            merge_stmt = self._create_node_merge_statement(node_keys=csv_dataframe.columns.tolist(),
+                                                           node_label=all_node_records[0][NODE_LABEL])
+
+            self._write_transactions(merge_stmt, all_node_records)
+
+    def _create_node_merge_statement(self, node_keys: list, node_label: str) -> str:
+        template = Template("""
+            UNWIND $batch AS row
+            MERGE (node:{{ LABEL }} {key: row.KEY})
+            ON CREATE SET {{ PROPS_BODY_CREATE }}
+            {% if update %} ON MATCH SET {{ PROPS_BODY_UPDATE }} {% endif %}
+        """)
+
+        props_body_create = self._create_props_body(get_props_body_keys(node_keys,
+                                                                        NODE_REQUIRED_KEYS,
+                                                                        self._additional_publisher_metadata_fields),
+                                                    'node')
+
+        props_body_update = props_body_create
+        if self._preserve_adhoc_ui_data:
+            props_body_update = self._create_props_body(get_props_body_keys(node_keys,
+                                                                            NODE_REQUIRED_KEYS,
+                                                                            self._additional_publisher_metadata_fields),
+                                                        'node', True)
+
+        return template.render(LABEL=node_label,
+                               PROPS_BODY_CREATE=props_body_create,
+                               PROPS_BODY_UPDATE=props_body_update,
+                               update=(node_label not in self._create_only_nodes))
+
+    def _publish_relation_file(self, relation_file: str) -> None:
+        with open(relation_file, 'r', encoding='utf8') as relation_csv:
+            csv_dataframe = pandas.read_csv(relation_csv, na_filter=False)
+            all_rel_records = csv_dataframe.to_dict(orient="records")
+
+            # Get the first relation labels since they will be the same for all records in the file
+            merge_stmt = self._create_relationship_merge_statement(
+                rel_keys=csv_dataframe.columns.tolist(),
+                start_label=all_rel_records[0][RELATION_START_LABEL],
+                end_label=all_rel_records[0][RELATION_END_LABEL],
+                relation_type=all_rel_records[0][RELATION_TYPE],
+                relation_reverse_type=all_rel_records[0][RELATION_REVERSE_TYPE]
+            )
+
+            self._write_transactions(merge_stmt, all_rel_records)
+
+    def _create_relationship_merge_statement(self,
+                                             rel_keys: list,
+                                             start_label: str,
+                                             end_label: str,
+                                             relation_type: str,
+                                             relation_reverse_type: str) -> str:
+        template = Template("""
+            UNWIND $batch as row
+            MATCH (n1:{{ START_LABEL }} {key: row.START_KEY}), (n2:{{ END_LABEL }} {key: row.END_KEY})
+            {% if publish_reverse_relationships %}
+            MERGE (n1)-[r1:{{ TYPE }}]->(n2)-[r2:{{ REVERSE_TYPE }}]->(n1)
+            {% else %}
+            MERGE (n1)-[r1:{{ TYPE }}]->(n2)
+            {% endif %}
+            {% if update_props_body %}
+            ON CREATE SET {{ props_body_create }}
+            ON MATCH SET {{ props_body_update }}
+            {% endif %}
+            RETURN n1.key, n2.key
+        """)
+
+        props_body_template = Template("""{{ props_body_r1 }} , {{ props_body_r2 }}""")
+
+        props_body_r1 = self._create_props_body(get_props_body_keys(rel_keys,
+                                                                    RELATION_REQUIRED_KEYS,
+                                                                    self._additional_publisher_metadata_fields), 'r1')
+        props_body_r2 = self._create_props_body(get_props_body_keys(rel_keys,
+                                                                    RELATION_REQUIRED_KEYS,
+                                                                    self._additional_publisher_metadata_fields), 'r2')
+        if self._publish_reverse_relationships:
+            props_body_create = props_body_template.render(props_body_r1=props_body_r1, props_body_r2=props_body_r2)
+        else:
+            props_body_create = props_body_r1
+
+        props_body_update = props_body_create
+        if self._preserve_adhoc_ui_data:
+            props_body_r1 = self._create_props_body(get_props_body_keys(rel_keys,
+                                                                        RELATION_REQUIRED_KEYS,
+                                                                        self._additional_publisher_metadata_fields),
+                                                    'r1', True)
+            props_body_r2 = self._create_props_body(get_props_body_keys(rel_keys,
+                                                                        RELATION_REQUIRED_KEYS,
+                                                                        self._additional_publisher_metadata_fields),
+                                                    'r2', True)
+            if self._publish_reverse_relationships:
+                props_body_update = props_body_template.render(props_body_r1=props_body_r1, props_body_r2=props_body_r2)
+            else:
+                props_body_update = props_body_r1
+
+        return template.render(START_LABEL=start_label,
+                               END_LABEL=end_label,
+                               publish_reverse_relationships=self._publish_reverse_relationships,
+                               TYPE=relation_type,
+                               REVERSE_TYPE=relation_reverse_type,
+                               update_props_body=props_body_r1,
+                               props_body_create=props_body_create,
+                               props_body_update=props_body_update)
+
+    def _create_props_body(self,
+                           record_keys: Set,
+                           identifier: str,
+                           rename_id_to_preserve_ui_data: bool = False) -> str:
+        """
+        Creates the properties body used to resolve the merge statement templates.
+
+        e.g. with identifier 'node':
+            node.key1 = row.key1 , node.key2 = row.key2
+
+        :param record_keys: a set of keys for a CSV row
+        :param identifier: identifier that will be used in the Cypher query as shown in the example above
+        :param rename_id_to_preserve_ui_data: whether to wrap the identifier in a CASE expression that
+            evaluates to null for data created through the UI, preventing it from being updated
+        :return: Properties body for the Cypher statement
+        """
+        # For SET, if the evaluated expression is null, no action is performed. I.e. `SET (null).foo = 5` is a noop.
+        # See https://neo4j.com/docs/cypher-manual/current/clauses/set/
+        if rename_id_to_preserve_ui_data:
+            identifier = f"""
+            (CASE WHEN {identifier}.{PublisherConfigs.PUBLISHED_TAG_PROPERTY_NAME} IS NOT NULL
+            THEN {identifier} ELSE null END)
+            """
+
+        template = Template("""
+            {% for k in record_keys %}
+                {{ identifier }}.{{ k }} = row.{{ k }}
+                {{ ", " if not loop.last else "" }}
+            {% endfor %}
+            {% if record_keys and add_publisher_metadata %}
+                ,
+            {% endif %}
+            {% if add_publisher_metadata %}
+                {{ identifier }}.{{ published_tag_prop }} = '{{ publish_tag }}',
+                {{ identifier }}.{{ last_updated_prop }} = timestamp()
+            {% endif %}
+        """)
+
+        props_body = template.render(record_keys=record_keys,
+                                     identifier=identifier,
+                                     add_publisher_metadata=self._add_publisher_metadata,
+                                     published_tag_prop=PublisherConfigs.PUBLISHED_TAG_PROPERTY_NAME,
+                                     publish_tag=self._publish_tag,
+                                     last_updated_prop=PublisherConfigs.LAST_UPDATED_EPOCH_MS)
+        return props_body.strip()
+
+    def _write_transactions(self,
+                            stmt: str,
+                            records: List[dict]) -> None:
+        for chunk in chunkify_list(records, self._transaction_size):
+            params_list = []
+            for record in chunk:
+                params_list.append(create_props_param(record, self._additional_publisher_metadata_fields))
+
+            with self._driver.session(database=self._db_name) as session:
+                session.write_transaction(execute_neo4j_statement, stmt, {'batch': params_list})
+
+            self._count += len(params_list)
+            LOGGER.info('Committed %i rows so far', self._count)
diff --git a/databuilder/databuilder/publisher/publisher_config_constants.py b/databuilder/databuilder/publisher/publisher_config_constants.py
new file mode 100644
index 0000000000..be664697a0
--- /dev/null
+++ b/databuilder/databuilder/publisher/publisher_config_constants.py
@@ -0,0 +1,62 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+
+class PublisherConfigs:
+    # A directory that contains CSV files for nodes
+    NODE_FILES_DIR = 'node_files_directory'
+    # A directory that contains CSV files for relationships
+    RELATION_FILES_DIR = 'relation_files_directory'
+
+    # A CSV header with this suffix will be passed to the statement without quotes
+    UNQUOTED_SUFFIX = ':UNQUOTED'
+
+    # This will be used to provide a unique tag to the node and relationship
+    JOB_PUBLISH_TAG = 'job_publish_tag'
+
+    # Any additional fields that should be added to nodes and rels through config
+    ADDITIONAL_PUBLISHER_METADATA_FIELDS = 'additional_publisher_metadata_fields'
+
+    # Property name for published tag
+    PUBLISHED_TAG_PROPERTY_NAME = 'published_tag'
+    # Property name for last updated timestamp
+    LAST_UPDATED_EPOCH_MS = 'publisher_last_updated_epoch_ms'
+
+
+class PublishBehaviorConfigs:
+    # A boolean flag to indicate if publisher_metadata (e.g. published_tag,
+    # publisher_last_updated_epoch_ms)
+    # will be included as properties of the nodes
+    ADD_PUBLISHER_METADATA = 'add_publisher_metadata'
+
+    # NOTE: Do not use this unless you have a specific use case for it. Amundsen expects two-way relationships,
+    # and the default value should be set to true to publish relations in both directions. If it is overridden
+    # and set to false, reverse relationships will not be published.
+    PUBLISH_REVERSE_RELATIONSHIPS = 'publish_reverse_relationships'
+
+    # If enabled, stops the publisher from updating a node or relationship
+    # created via the UI, e.g. a description or owner added manually by an Amundsen user.
+    # Such nodes/relationships will not have a 'published_tag' property that is set by databuilder.
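+    # For illustration, Neo4jCsvUnwindPublisher._create_props_body implements this by wrapping the
+    # identifier used in the ON MATCH SET clause with a guard roughly like:
+    #   (CASE WHEN node.published_tag IS NOT NULL THEN node ELSE null END).name = row.name
+    # Since SET on a null expression is a no-op, matched rows holding UI-created data are skipped.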
+    PRESERVE_ADHOC_UI_DATA = 'preserve_adhoc_ui_data'
+
+
+class Neo4jCsvPublisherConfigs:
+    # An endpoint for Neo4j, e.g. bolt://localhost:9999
+    NEO4J_END_POINT_KEY = 'neo4j_endpoint'
+    # A transaction size that determines how often it commits.
+    NEO4J_TRANSACTION_SIZE = 'neo4j_transaction_size'
+
+    NEO4J_MAX_CONN_LIFE_TIME_SEC = 'neo4j_max_conn_life_time_sec'
+
+    # List of node labels that are create-only, and not updated if a match exists
+    NEO4J_CREATE_ONLY_NODES = 'neo4j_create_only_nodes'
+
+    NEO4J_USER = 'neo4j_user'
+    NEO4J_PASSWORD = 'neo4j_password'
+    # in Neo4j (v4.0+), we can create and use more than one active database at the same time
+    NEO4J_DATABASE_NAME = 'neo4j_database'
+
+    # NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting
+    NEO4J_ENCRYPTED = 'neo4j_encrypted'
+    # NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS
+    # cert against system CAs
+    NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl'
diff --git a/databuilder/databuilder/utils/publisher_utils.py b/databuilder/databuilder/utils/publisher_utils.py
new file mode 100644
index 0000000000..08b8f8aeb3
--- /dev/null
+++ b/databuilder/databuilder/utils/publisher_utils.py
@@ -0,0 +1,119 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from os import listdir
+from os.path import isfile, join
+from typing import (
+    Iterator, List, Set,
+)
+
+import pandas
+from jinja2 import Template
+from neo4j import Neo4jDriver, Transaction
+from neo4j.exceptions import Neo4jError
+from pyhocon import ConfigTree
+
+from databuilder.models.graph_serializable import NODE_LABEL
+from databuilder.publisher.publisher_config_constants import PublisherConfigs
+
+LOGGER = logging.getLogger(__name__)
+
+
+def chunkify_list(records: List[dict], chunk_size: int) -> Iterator[List[dict]]:
+    """
+    Generator to evenly split the input list into chunks
+    """
+    for index in range(0, len(records), chunk_size):
+        yield records[index:index + chunk_size]
+
+
+def create_neo4j_node_key_constraint(node_file: str,
+                                     current_labels: Set,
+                                     driver: Neo4jDriver,
+                                     db_name: str) -> Set:
+    """
+    Goes over the node file and tries to create a unique index for any label this publisher
+    has not seen before. Neo4j ignores a second creation in 3.x, but raises an error in 4.x.
+    """
+    LOGGER.info('Creating indices using Node file: %s. (Existing indices will be ignored)', node_file)
+
+    labels = set(current_labels)
+    with open(node_file, 'r', encoding='utf8') as node_csv:
+        for node_record in pandas.read_csv(node_csv,
+                                           na_filter=False).to_dict(orient='records'):
+            label = node_record[NODE_LABEL]
+            if label not in labels:
+                with driver.session(database=db_name) as session:
+                    try:
+                        create_stmt = Template("""
+                            CREATE CONSTRAINT ON (node:{{ LABEL }}) ASSERT node.key IS UNIQUE
+                        """).render(LABEL=label)
+
+                        LOGGER.info('Trying to create index for label %s if not exist: %s', label, create_stmt)
+
+                        session.write_transaction(execute_neo4j_statement, create_stmt)
+                    except Neo4jError as e:
+                        if 'An equivalent constraint already exists' not in str(e):
+                            raise
+                        # Else, swallow the exception, to make this function idempotent.
+                labels.add(label)
+
+    LOGGER.info('Indices have been created.')
+    return labels
+
+
+def create_props_param(record_dict: dict, additional_publisher_metadata_fields: dict) -> dict:
+    """
+    Create a dict of all the params for a given record
+    """
+    params = {}
+
+    for k, v in {**record_dict, **additional_publisher_metadata_fields}.items():
+        params[strip_unquoted_suffix(k)] = v
+
+    return params
+
+
+def execute_neo4j_statement(tx: Transaction,
+                            stmt: str,
+                            params: dict = None) -> None:
+    """
+    Executes a statement against Neo4j. If execution fails, the transaction is rolled back
+    and an exception is raised.
+    """
+    LOGGER.debug('Executing statement: %s with params %s', stmt, params)
+
+    tx.run(stmt, parameters=params)
+
+
+def get_props_body_keys(record_keys: list,
+                        exclude_keys: Set,
+                        additional_publisher_metadata_fields: dict) -> Set:
+    """
+    Returns the set of keys to be used in the props body of the merge statements
+    :param record_keys: a list of keys for a CSV row
+    :param exclude_keys: set of excluded columns that do not need to be in properties (e.g. KEY, LABEL ...)
+    :param additional_publisher_metadata_fields: additional fields to include in the props body
+    """
+    props_body_keys = set(record_keys) - exclude_keys
+    formatted_keys = map(strip_unquoted_suffix, props_body_keys)
+    return set(formatted_keys).union(additional_publisher_metadata_fields.keys())
+
+
+def list_files(conf: ConfigTree, path_key: str) -> List[str]:
+    """
+    List files from a directory
+    :param conf: config that may contain the directory path
+    :param path_key: config key for the directory path
+    :return: List of file paths
+    """
+    if path_key not in conf:
+        return []
+
+    path = conf.get_string(path_key)
+    return [join(path, f) for f in listdir(path) if isfile(join(path, f))]
+
+
+def strip_unquoted_suffix(key: str) -> str:
+    return key[:-len(PublisherConfigs.UNQUOTED_SUFFIX)] if key.endswith(PublisherConfigs.UNQUOTED_SUFFIX) else key
diff --git a/databuilder/setup.py b/databuilder/setup.py
index b589047239..0919896126 100644
--- a/databuilder/setup.py
+++ b/databuilder/setup.py
@@ -5,7 +5,7 @@
 
 from setuptools import find_packages, setup
 
-__version__ = '7.1.2'
+__version__ = '7.2.0'
 
 requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'requirements.txt')
diff --git a/databuilder/tests/unit/publisher/test_neo4j_csv_unwind_publisher.py b/databuilder/tests/unit/publisher/test_neo4j_csv_unwind_publisher.py
new file mode 100644
index 0000000000..230319a0c4
--- /dev/null
+++ b/databuilder/tests/unit/publisher/test_neo4j_csv_unwind_publisher.py
@@ -0,0 +1,74 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import unittest
+import uuid
+
+from mock import MagicMock, patch
+from neo4j import GraphDatabase
+from pyhocon import ConfigFactory
+
+from databuilder.publisher.neo4j_csv_unwind_publisher import Neo4jCsvUnwindPublisher
+from databuilder.publisher.publisher_config_constants import Neo4jCsvPublisherConfigs, PublisherConfigs
+
+here = os.path.dirname(__file__)
+
+
+class TestPublish(unittest.TestCase):
+
+    def setUp(self) -> None:
+        logging.basicConfig(level=logging.INFO)
+        self._resource_path = os.path.join(here, '../resources/csv_publisher')
+
+    def test_publisher_write_exception(self) -> None:
+        with patch.object(GraphDatabase, 'driver') as mock_driver:
+            mock_session = MagicMock()
+            mock_driver.return_value.session.return_value = mock_session
+
+            mock_write_transaction = MagicMock(side_effect=Exception('Could not write'))
+            mock_session.__enter__.return_value.write_transaction = mock_write_transaction
+
+            publisher = Neo4jCsvUnwindPublisher()
+
+            conf = ConfigFactory.from_dict(
+                {Neo4jCsvPublisherConfigs.NEO4J_END_POINT_KEY: 'bolt://999.999.999.999:7687/',
+                 PublisherConfigs.NODE_FILES_DIR: f'{self._resource_path}/nodes',
+                 PublisherConfigs.RELATION_FILES_DIR: f'{self._resource_path}/relations',
+                 Neo4jCsvPublisherConfigs.NEO4J_USER: 'neo4j_user',
+                 Neo4jCsvPublisherConfigs.NEO4J_PASSWORD: 'neo4j_password',
+                 PublisherConfigs.JOB_PUBLISH_TAG: str(uuid.uuid4())}
+            )
+            publisher.init(conf)
+
+            with self.assertRaises(Exception):
+                publisher.publish()
+
+    def test_publisher(self) -> None:
+        with patch.object(GraphDatabase, 'driver') as mock_driver:
+            mock_session = MagicMock()
+            mock_driver.return_value.session.return_value = mock_session
+
+            mock_write_transaction = MagicMock()
+            mock_session.__enter__.return_value.write_transaction = mock_write_transaction
+
+            publisher = Neo4jCsvUnwindPublisher()
+
+            conf = ConfigFactory.from_dict(
+                {Neo4jCsvPublisherConfigs.NEO4J_END_POINT_KEY: 'bolt://999.999.999.999:7687/',
+                 PublisherConfigs.NODE_FILES_DIR: f'{self._resource_path}/nodes',
+                 PublisherConfigs.RELATION_FILES_DIR: f'{self._resource_path}/relations',
+                 Neo4jCsvPublisherConfigs.NEO4J_USER: 'neo4j_user',
+                 Neo4jCsvPublisherConfigs.NEO4J_PASSWORD: 'neo4j_password',
+                 PublisherConfigs.JOB_PUBLISH_TAG: str(uuid.uuid4())}
+            )
+            publisher.init(conf)
+            publisher.publish()
+
+            # Create 2 indices, write 2 node files, write 1 relation file
+            self.assertEqual(5, mock_write_transaction.call_count)
+
+
+if __name__ == '__main__':
+    unittest.main()