Skip to content

Commit

Permalink
feat(tree): add file_wildcard to filter files (bytedance#981)
Browse files Browse the repository at this point in the history
  • Loading branch information
gejielun authored Jun 7, 2022
1 parent ac7e381 commit 578a3ea
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 20 deletions.
4 changes: 3 additions & 1 deletion deploy/scripts/trainer/run_tree_worker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ data_path=$(normalize_env_to_args "--data-path" "$DATA_PATH")
validation_data_path=$(normalize_env_to_args "--validation-data-path" "$VALIDATION_DATA_PATH")
no_data=$(normalize_env_to_args "--no-data" "$NO_DATA")
file_ext=$(normalize_env_to_args "--file-ext" "$FILE_EXT")
file_wildcard=$(normalize_env_to_args "--file-wildcard" "$FILE_WILDCARD")
file_type=$(normalize_env_to_args "--file-type" "$FILE_TYPE")
load_model_path=$(normalize_env_to_args "--load-model-path" "$LOAD_MODEL_PATH")
verbosity=$(normalize_env_to_args "--verbosity" "$VERBOSITY")
Expand Down Expand Up @@ -76,4 +77,5 @@ python -m fedlearner.model.tree.trainer \
$max_depth $l2_regularization $max_bins \
$num_parallel $verify_example_ids $ignore_fields \
$cat_fields $use_streaming $send_scores_to_follower \
$send_metrics_to_follower $enable_packing $label_field
$send_metrics_to_follower $enable_packing $label_field \
$file_wildcard
12 changes: 12 additions & 0 deletions example/tree_model/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ python -m fedlearner.model.tree.trainer follower \
--peer-addr=localhost:50051 \
--verify-example-ids=true \
--file-ext=.tfrecord \
--file-wildcard=*tfrecord \
--file-type=tfrecord \
--data-path=data/follower_train.tfrecord \
--validation-data-path=data/follower_test \
Expand All @@ -27,6 +28,7 @@ python -m fedlearner.model.tree.trainer leader \
--peer-addr=localhost:50052 \
--verify-example-ids=true \
--file-ext=.tfrecord \
--file-wildcard=*tfrecord \
--file-type=tfrecord \
--data-path=data/leader_train.tfrecord \
--validation-data-path=data/leader_test \
Expand All @@ -44,6 +46,7 @@ python -m fedlearner.model.tree.trainer leader \
--verify-example-ids=true \
--file-type=tfrecord \
--file-ext=.tfrecord \
--file-wildcard=*tfrecord \
--data-path=data/leader_test/ \
--cat-fields=f00001 \
--load-model-path=exp/leader_checkpoints/checkpoint-0004.proto \
Expand All @@ -57,6 +60,7 @@ python -m fedlearner.model.tree.trainer follower \
--verify-example-ids=true \
--file-type=tfrecord \
--file-ext=.tfrecord \
--file-wildcard=*tfrecord \
--data-path=data/follower_test/ \
--cat-fields=f00001 \
--load-model-path=exp/follower_checkpoints/checkpoint-0004.proto \
Expand All @@ -73,7 +77,9 @@ python -m fedlearner.model.tree.trainer follower \
--verbosity=1 \
--local-addr=localhost:50052 \
--peer-addr=localhost:50051 \
--file-ext=.csv \
--file-type=csv \
--file-wildcard=*csv \
--data-path=data/follower_train.csv \
--cat-fields=f00001 \
--checkpoint-path=exp/follower_checkpoints \
Expand All @@ -83,7 +89,9 @@ python -m fedlearner.model.tree.trainer leader \
--verbosity=1 \
--local-addr=localhost:50051 \
--peer-addr=localhost:50052 \
--file-ext=.csv \
--file-type=csv \
--file-wildcard=*csv \
--data-path=data/leader_train.csv \
--ignore-fields=f00000,f00001 \
--checkpoint-path=exp/leader_checkpoints \
Expand All @@ -96,6 +104,8 @@ python -m fedlearner.model.tree.trainer follower \
--local-addr=localhost:50052 \
--peer-addr=localhost:50051 \
--mode=test \
--file-ext=.csv \
--file-wildcard=*csv \
--file-type=csv \
--data-path=data/follower_test/ \
--cat-fields=f00001 \
Expand All @@ -107,6 +117,8 @@ python -m fedlearner.model.tree.trainer leader \
--local-addr=localhost:50051 \
--peer-addr=localhost:50052 \
--mode=test \
--file-ext=.csv \
--file-wildcard=*csv \
--no-data=true \
--load-model-path=exp/leader_checkpoints/checkpoint-0004.proto \
--output-path=exp/leader_test_output
Expand Down
53 changes: 34 additions & 19 deletions fedlearner/model/tree/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import argparse
import traceback
import itertools
from fnmatch import fnmatch
from typing import List
import numpy as np

import tensorflow.compat.v1 as tf
Expand Down Expand Up @@ -61,8 +63,12 @@ def create_argument_parser():
parser.add_argument('--no-data', type=str_as_bool,
default=False, const=True, nargs='?',
help='Run prediction without data.')
parser.add_argument('--file-ext', type=str, default='.csv',
help='File extension to use')
parser.add_argument('--file-ext', type=str, default='',
help='File extension to use including .' \
'for example: .csv')
# TODO(gezhengqiang): delete file_ext
parser.add_argument('--file-wildcard', type=str, default='',
help='the wildcard filter for the file')
parser.add_argument('--file-type', type=str, default='csv',
help='input file type: csv or tfrecord')
parser.add_argument('--load-model-path',
Expand Down Expand Up @@ -183,6 +189,21 @@ def extract_field(field_names, field_name, required):
return None


def filter_files(path: str, file_ext: str, file_wildcard: str) -> List[str]:
files = []
for dirname, _, filenames in tf.io.gfile.walk(path):
for filename in filenames:
_, ext = os.path.splitext(filename)
subdirname = os.path.join(path, os.path.relpath(dirname, path))
fpath = os.path.join(subdirname, filename)
if file_ext and ext != file_ext:
continue
if file_wildcard and not fnmatch(fpath, file_wildcard):
continue
files.append(fpath)
return files


def read_data(file_type, filename, require_example_ids, require_labels,
ignore_fields, cat_fields, label_field):
logging.debug('Reading data file from %s', filename)
Expand Down Expand Up @@ -242,22 +263,15 @@ def to_float(x):
labels, example_ids, raw_ids


def read_data_dir(file_ext, file_type, path, require_example_ids,
require_labels, ignore_fields, cat_fields, label_field):
def read_data_dir(file_ext: str, file_wildcard: str, file_type: str, path: str,
require_example_ids: bool, require_labels: bool,
ignore_fields: str, cat_fields: str, label_field: str):
if not tf.io.gfile.isdir(path):
return read_data(
file_type, path, require_example_ids,
require_labels, ignore_fields, cat_fields, label_field)

files = []
for dirname, _, filenames in tf.io.gfile.walk(path):
for filename in filenames:
_, ext = os.path.splitext(filename)
if file_ext and ext != file_ext:
continue
subdirname = os.path.join(path, os.path.relpath(dirname, path))
files.append(os.path.join(subdirname, filename))

files = filter_files(path, file_ext, file_wildcard)
files.sort()
features = None
for fullname in files:
Expand Down Expand Up @@ -299,17 +313,18 @@ def read_data_dir(file_ext, file_type, path, require_example_ids,

def train(args, booster):
X, cat_X, X_names, cat_X_names, y, example_ids, _ = read_data_dir(
args.file_ext, args.file_type, args.data_path, args.verify_example_ids,
args.role != 'follower', args.ignore_fields, args.cat_fields,
args.label_field)
args.file_ext, args.file_wildcard, args.file_type, args.data_path,
args.verify_example_ids, args.role != 'follower', args.ignore_fields,
args.cat_fields, args.label_field)

if args.validation_data_path:
val_X, val_cat_X, val_X_names, val_cat_X_names, val_y, \
val_example_ids, _ = \
read_data_dir(
args.file_ext, args.file_type, args.validation_data_path,
args.verify_example_ids, args.role != 'follower',
args.ignore_fields, args.cat_fields, args.label_field)
args.file_ext, args.file_wildcard, args.file_type,
args.validation_data_path, args.verify_example_ids,
args.role != 'follower', args.ignore_fields,
args.cat_fields, args.label_field)
assert X_names == val_X_names, \
"Train data and validation data must have same features"
assert cat_X_names == val_cat_X_names, \
Expand Down
36 changes: 36 additions & 0 deletions test/tree_model/test_filter_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import tempfile
import unittest
from pathlib import Path
from fedlearner.model.tree.trainer import filter_files

class TestFilterFiles(unittest.TestCase):

def test_filter_files(self):
path = tempfile.mkdtemp()
path = Path(path, 'test').resolve()
path.mkdir()
path.joinpath('test1').mkdir()
path.joinpath('test2').mkdir()
path.joinpath('3.csv').touch()
path.joinpath('3.tfrecord').touch()
path.joinpath('test1').joinpath('1.csv').touch()
path.joinpath('test1').joinpath('2.tfrecord').touch()
path.joinpath('test2').joinpath('2.csv').touch()
path.joinpath('test2').joinpath('1.tfrecord').touch()
path.joinpath('test1/test').mkdir()
path.joinpath('test1/test').joinpath('4.csv').touch()
path.joinpath('test2/test').mkdir()
path.joinpath('test2/test').joinpath('4.tfrecord').touch()

files = filter_files(path, '.csv', '')
self.assertEqual(len(files), 4)
files = filter_files(path, '', '*tfr*')
self.assertEqual(len(files), 4)
files = filter_files(path, '', '')
self.assertEqual(len(files), 8)
files = filter_files(path, '.csv', '*1.*')
self.assertEqual(len(files), 1)


if __name__ == '__main__':
unittest.main()

0 comments on commit 578a3ea

Please sign in to comment.