@@ -1,2 +1,3 @@
+from collections import namedtuple
 import os
 from basepy.log import logger
@@ -39,30 +40,43 @@ class SplitByFunction(Function):
     function_name = 'splitby'
 
     def process(self, data_file, context):
-        def write_to_group(group_name, line):
-            if group_name not in group_file_savers:
+        def write_to_group(group_names, line):
+            if group_names not in group_file_savers:
+                # Join the per-key values into a single label for the file name.
+                group_name = '-'.join(str(g) for g in group_names)
                 dst_path = os.path.join(context.temp_dir, f'{data_file.name}-group-{group_name}.jsonl')
                 file_saver = AtomicSaver(dst_path)
                 file_saver.setup()
-                group_file_savers[group_name] = file_saver
-            saver = group_file_savers[group_name]
+                group_file_savers[group_names] = file_saver
+            saver = group_file_savers[group_names]
             saver.part_file.write(line.encode('utf-8'))
             saver.part_file.write(b'\n')
 
         data_files = []
         group_file_savers = {}
-        split_key = self.args['key'][0]
+        split_keys = self.args['key']
+        tags = self.args.get('tags', {})
+        tags_with_group = {}
         file_reader = DataFileReader(data_file.file_path)
         for (data, line) in file_reader.readlines():
-            group_name = data.get(split_key)
-            if not group_name:
+            # Build a composite group key from all configured split keys.
+            group_names = tuple(data.get(split_key) for split_key in split_keys)
+            if group_names not in tags_with_group:
+                # Wrap the record in a namedtuple so tag templates can use
+                # attribute access, e.g. '{data.region}'.
+                data_object = namedtuple("DataObject", data.keys())(*data.values())
+                fill_tags = {}
+                for tag_k, tag_v in tags.items():
+                    fill_tags[tag_k] = tag_v.format(data=data_object)
+                tags_with_group[group_names] = fill_tags
+            if not all(group_names):
                 # TODO: warning
                 continue
-            write_to_group(group_name, line)
-
-        for saver in group_file_savers.values():
-            saver.__exit__(None, None, None)
-            data_files.append(context.create_data_file(file_path=saver.dest_path))
+            write_to_group(group_names, line)
+        for group_names, saver in group_file_savers.items():
+            saver.__exit__(None, None, None)
+            group_tags = tags_with_group.get(group_names) or None
+            data_files.append(context.create_data_file(file_path=saver.dest_path, tags=group_tags))
         return data_files
 
 
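For illustration, a splitby step with multiple keys and tag templates might be configured as below. This is a sketch: the args shape follows the code above, but the field names ('region', 'event', 'date') are assumptions, not taken from this diff.

    # Hypothetical splitby config (field names assumed).
    args = {
        'key': ['region', 'event'],        # one output group per (region, event) pair
        'tags': {
            'region': '{data.region}',     # filled via namedtuple attribute access
            'day': '{data.date}',
        },
    }
    # A record {'region': 'eu', 'event': 'click', 'date': '2021-01-01'} is written
    # to '<data_file.name>-group-eu-click.jsonl', and the resulting data file
    # carries tags {'region': 'eu', 'day': '2021-01-01'}.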
@@ -72,10 +86,16 @@ class SaveFunction(FunctionMultiMixin, Function):
     def process(self, data_file, context):
         logger.debug('save function process', data_file=data_file.file_path)
         location = self.args.get('location')
+        path_suffix = self.args.get('path_suffix')
+        if path_suffix:
+            path_suffix = path_suffix.format(**(data_file.tags or {}))  # placeholders filled from tags
+            if not path_suffix.endswith('/'):
+                path_suffix = path_suffix + '/'
         storage = context.get_storage(location)
         if not storage:
             raise Exception('No storage defined.')
-        storage.save(data_file.basename, data_file.file_path)
+        key = path_suffix + data_file.basename if path_suffix else data_file.basename
+        storage.save(key, data_file.file_path)
         return data_file
 
 
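A sketch of how path_suffix composes with the tags set upstream (placeholder names assumed):

    # Hypothetical save config; '{region}' and '{day}' must be keys in data_file.tags.
    args = {'location': 's3', 'path_suffix': '{region}/{day}'}
    # With tags {'region': 'eu', 'day': '2021-01-01'} the suffix is formatted to
    # 'eu/2021-01-01/' (trailing slash appended), so the storage key becomes
    # 'eu/2021-01-01/<data_file.basename>'.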
@@ -102,7 +122,7 @@ def process(self, data_file, context):
         file_saver.part_file.write(b'\n')
 
         file_saver.__exit__(None, None, None)
-        new_data_file = context.create_data_file(dst_path, file_type="index")
+        new_data_file = context.create_data_file(dst_path, file_type="index", tags=data_file.tags)
         return [data_file, new_data_file]
 
 
@@ -120,7 +140,7 @@ def process(self, data_file, context):
         for data, line in file_reader.readlines():
             f.write(json.dumps(common.flatten_dict(data)).encode('utf-8'))
             f.write(b'\n')
-        return context.create_data_file(file_path=dst_path)
+        return context.create_data_file(file_path=dst_path, tags=data_file.tags)
 
 
 class FormatFunction(FunctionMultiMixin, Function):
@@ -184,7 +204,7 @@ def process(self, data_file, context):
             pk_values.add(pk_value)
             f.write(json.dumps(data).encode('utf-8'))
             f.write(b'\n')
-        return data_file, context.create_data_file(file_path=dst_path)
+        return data_file, context.create_data_file(file_path=dst_path, tags=data_file.tags)
 
 
 class FilterFunction(FunctionMultiMixin, Function):
@@ -200,7 +220,7 @@ def process(self, data_file, context):
             tags = rule_config.get('tags')
             rule = rule_config.get('rule', "False")
 
-            dst_path = os.path.join(context.temp_dir, f'{data_file.name}-filter-{tags if tags else "default"}.jsonl')
+            dst_path = os.path.join(context.temp_dir, f'{data_file.name}-filter-{"_".join(tags.values()) if tags else "default"}.jsonl')
             file_saver = AtomicSaver(dst_path)
             file_saver.setup()
 
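Since a filter rule's tags is now a dict, the temp file name joins its values; a sketch (rule schema assumed):

    # Hypothetical filter rule config.
    rule_config = {'tags': {'kind': 'error', 'source': 'api'},
                   'rule': "data['level'] == 'error'"}
    # '_'.join(tags.values()) yields 'error_api', so the temp file is
    # '<data_file.name>-filter-error_api.jsonl'.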