diff --git a/sql/data_masking.py b/sql/data_masking.py index dba9d274..05c92f71 100644 --- a/sql/data_masking.py +++ b/sql/data_masking.py @@ -145,9 +145,14 @@ def analy_query_tree(self, query_tree, cluster_name): hit_columns = [] # 命中列 table_hit_columns = [] # 涉及表命中的列 - # 获取select信息的规则,仅处理type为FIELD_ITEM的select信息,如[*],[*,column_a],[column_a,*],[column_a,a.*,column_b],[a.*,column_a,b.*], - select_index = [select_item['field'] for select_item in select_list if - select_item['type'] == 'FIELD_ITEM'] + # 获取select信息的规则,仅处理type为FIELD_ITEM和aggregate类型的select信息,如[*],[*,column_a],[column_a,*],[column_a,a.*,column_b],[a.*,column_a,b.*], + select_index = [ + select_item['field'] if select_item['type'] == 'FIELD_ITEM' else select_item['aggregate']['field'] for + select_item in select_list if select_item['type'] in ('FIELD_ITEM', 'aggregate')] + + # 处理select_list,为统一的{'type': 'FIELD_ITEM', 'db': 'archer_master', 'table': 'sql_users', 'field': 'email'}格式 + select_list = [select_item if select_item['type'] == 'FIELD_ITEM' else select_item['aggregate'] for + select_item in select_list if select_item['type'] in ('FIELD_ITEM', 'aggregate')] if select_index: # 如果发现存在field='*',则遍历所有表,找出所有的命中字段 @@ -163,86 +168,63 @@ def analy_query_tree(self, query_tree, cluster_name): elif re.match(r"^(\*,)+(\w,?)+$", ','.join(select_index)): # 找出field不为* 的列信息, 循环判断列是否命中脱敏规则,并增加规则类型和index,index采取后切片 for index, item in enumerate(select_list): - if item['type'] == 'FIELD_ITEM': - item['index'] = index - len(select_list) - if item['field'] != '*': - columns.append(item) - - for column in columns: - hit_info = self.hit_column(DataMaskingColumnsOb, cluster_name, column['db'], - column['table'], column['field']) - if hit_info['is_hit']: - hit_info['index'] = column['index'] - hit_columns.append(hit_info) + item['index'] = index - len(select_list) + if item['field'] != '*': + columns.append(item) + # [column_a, *] elif re.match(r"^(\w,?)+(\*,?)+$", ','.join(select_index)): # 找出field不为* 的列信息, 循环判断列是否命中脱敏规则,并增加规则类型和index,index采取前切片 for index, item in enumerate(select_list): - if item['type'] == 'FIELD_ITEM': - item['index'] = index - if item['field'] != '*': - columns.append(item) - - for column in columns: - hit_info = self.hit_column(DataMaskingColumnsOb, cluster_name, column['db'], - column['table'], column['field']) - if hit_info['is_hit']: - hit_info['index'] = column['index'] - hit_columns.append(hit_info) + item['index'] = index + if item['field'] != '*': + columns.append(item) + # [column_a,a.*,column_b] elif re.match(r"^(\w,?)+(\*,?)+(\w,?)+$", ','.join(select_index)): # 找出field不为* 的列信息, 循环判断列是否命中脱敏规则,并增加规则类型和index,*前面的字段index采取前切片,*后面的字段采取后切片 for index, item in enumerate(select_list): - if item['type'] == 'FIELD_ITEM': - item['index'] = index - if item['field'] == '*': - first_idx = index - break + item['index'] = index + if item['field'] == '*': + first_idx = index + break select_list.reverse() for index, item in enumerate(select_list): - if item['type'] == 'FIELD_ITEM': - item['index'] = index - if item['field'] == '*': - last_idx = len(select_list) - index - 1 - break + item['index'] = index + if item['field'] == '*': + last_idx = len(select_list) - index - 1 + break select_list.reverse() for index, item in enumerate(select_list): - if item['type'] == 'FIELD_ITEM': - if item['field'] != '*' and index < first_idx: - item['index'] = index - columns.append(item) - - if item['field'] != '*' and index > last_idx: - item['index'] = index - len(select_list) - columns.append(item) - - for column in columns: - hit_info = self.hit_column(DataMaskingColumnsOb, cluster_name, column['db'], - column['table'], column['field']) - if hit_info['is_hit']: - hit_info['index'] = column['index'] - hit_columns.append(hit_info) + if item['field'] != '*' and index < first_idx: + item['index'] = index + + if item['field'] != '*' and index > last_idx: + item['index'] = index - len(select_list) + columns.append(item) # [a.*, column_a, b.*] else: - hit_columns = [] - return table_hit_columns, hit_columns + raise Exception('不支持select信息为[a.*, column_a, b.*]格式的查询脱敏!') + # 没有*的查询,直接遍历查询命中字段,query_tree的列index就是查询语句列的index else: for index, item in enumerate(select_list): - if item['type'] == 'FIELD_ITEM': - item['index'] = index - if item['field'] != '*': - columns.append(item) + item['index'] = index + if item['field'] != '*': + columns.append(item) + else: + raise Exception('不支持select信息包含函数的查询脱敏!') - for column in columns: - hit_info = self.hit_column(DataMaskingColumnsOb, cluster_name, column['db'], column['table'], - column['field']) - if hit_info['is_hit']: - hit_info['index'] = column['index'] - hit_columns.append(hit_info) + # 格式化命中的列信息 + for column in columns: + hit_info = self.hit_column(DataMaskingColumnsOb, cluster_name, column['db'], column['table'], + column['field']) + if hit_info['is_hit']: + hit_info['index'] = column['index'] + hit_columns.append(hit_info) return table_hit_columns, hit_columns # 判断字段是否命中脱敏规则,如果命中则返回脱敏的规则id和规则类型