From 85ad59b78ba54d01507ca7527e3d307ab815a82e Mon Sep 17 00:00:00 2001 From: hhyo Date: Fri, 3 Aug 2018 19:28:40 +0800 Subject: [PATCH] =?UTF-8?q?=E8=84=B1=E6=95=8F=E7=BB=86=E8=8A=82=E8=B0=83?= =?UTF-8?q?=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sql/data_masking.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/sql/data_masking.py b/sql/data_masking.py index abb76cc8..7dec2149 100644 --- a/sql/data_masking.py +++ b/sql/data_masking.py @@ -149,10 +149,13 @@ def analy_query_tree(self, query_tree, cluster_name): for select_item in select_list: if select_item['type'] not in ('FIELD_ITEM', 'aggregate'): raise Exception('不支持该查询语句脱敏!') + if select_item['type'] == 'aggregate': + if select_item['aggregate'].get('type') != 'FIELD_ITEM': + raise Exception('不支持该查询语句脱敏!') # 获取select信息的规则,仅处理type为FIELD_ITEM和aggregate类型的select信息,如[*],[*,column_a],[column_a,*],[column_a,a.*,column_b],[a.*,column_a,b.*], select_index = [ - select_item['field'] if select_item['type'] == 'FIELD_ITEM' else select_item['aggregate']['field'] for + select_item['field'] if select_item['type'] == 'FIELD_ITEM' else select_item['aggregate'].get('field') for select_item in select_list if select_item['type'] in ('FIELD_ITEM', 'aggregate')] # 处理select_list,为统一的{'type': 'FIELD_ITEM', 'db': 'archer_master', 'table': 'sql_users', 'field': 'email'}格式 @@ -174,7 +177,7 @@ def analy_query_tree(self, query_tree, cluster_name): # 找出field不为* 的列信息, 循环判断列是否命中脱敏规则,并增加规则类型和index,index采取后切片 for index, item in enumerate(select_list): item['index'] = index - len(select_list) - if item['field'] != '*': + if item.get('field') != '*': columns.append(item) # [column_a, *] @@ -182,7 +185,7 @@ def analy_query_tree(self, query_tree, cluster_name): # 找出field不为* 的列信息, 循环判断列是否命中脱敏规则,并增加规则类型和index,index采取前切片 for index, item in enumerate(select_list): item['index'] = index - if item['field'] != '*': + if item.get('field') != '*': columns.append(item) # [column_a,a.*,column_b] @@ -190,23 +193,23 @@ def analy_query_tree(self, query_tree, cluster_name): # 找出field不为* 的列信息, 循环判断列是否命中脱敏规则,并增加规则类型和index,*前面的字段index采取前切片,*后面的字段采取后切片 for index, item in enumerate(select_list): item['index'] = index - if item['field'] == '*': + if item.get('field') == '*': first_idx = index break select_list.reverse() for index, item in enumerate(select_list): item['index'] = index - if item['field'] == '*': + if item.get('field') == '*': last_idx = len(select_list) - index - 1 break select_list.reverse() for index, item in enumerate(select_list): - if item['field'] != '*' and index < first_idx: + if item.get('field') != '*' and index < first_idx: item['index'] = index - if item['field'] != '*' and index > last_idx: + if item.get('field') != '*' and index > last_idx: item['index'] = index - len(select_list) columns.append(item) @@ -218,13 +221,13 @@ def analy_query_tree(self, query_tree, cluster_name): else: for index, item in enumerate(select_list): item['index'] = index - if item['field'] != '*': + if item.get('field') != '*': columns.append(item) # 格式化命中的列信息 for column in columns: - hit_info = self.hit_column(DataMaskingColumnsOb, cluster_name, column['db'], column['table'], - column['field']) + hit_info = self.hit_column(DataMaskingColumnsOb, cluster_name, column.get('db'), column.get('table'), + column.get('field')) if hit_info['is_hit']: hit_info['index'] = column['index'] hit_columns.append(hit_info)