Add tfrecord support for tree model (bytedance#98)
* Add tfrecord support for tree model

* lint
piiswrong authored Jul 8, 2020
1 parent a0a92a8 commit 8dcaf0a
Showing 3 changed files with 128 additions and 63 deletions.
example/tree_model/make_data.py (43 additions, 17 deletions)
@@ -6,6 +6,7 @@
 import tensorflow as tf
 from sklearn.datasets import load_iris
 
+
 def quantize_data(header, dtypes, X):
     for i, (h, dtype) in enumerate(zip(header, dtypes)):
         if h[0] != 'f' or dtype != np.int32:
@@ -20,7 +21,18 @@ def quantize_data(header, dtypes, X):
 
     return np.asarray([tuple(i) for i in X], dtype=list(zip(header, dtypes)))
 
-def write_data(filename, X, y, role, verify_example_ids):
+def write_tfrecord_data(filename, data, header, dtypes):
+    fout = tf.io.TFRecordWriter(filename)
+    for i in range(data.shape[0]):
+        example = tf.train.Example()
+        for h, d, x in zip(header, dtypes, data[i]):
+            if d == np.int32:
+                example.features.feature[h].int64_list.value.append(x)
+            else:
+                example.features.feature[h].float_list.value.append(x)
+        fout.write(example.SerializeToString())
+
+def write_data(output_type, filename, X, y, role, verify_example_ids):
     if role == 'leader':
         data = np.concatenate((X[:, :X.shape[1]//2], y), axis=1)
         N = data.shape[1]-1
@@ -44,13 +56,16 @@ def write_data(filename, X, y, role, verify_example_ids):
         dtypes = [np.int32] + dtypes
 
     data = quantize_data(header, dtypes, data)
-    np.savetxt(
-        filename,
-        data,
-        delimiter=',',
-        header=','.join(header),
-        fmt=['%d' if i == np.int32 else '%f' for i in dtypes],
-        comments='')
+    if output_type == 'tfrecord':
+        write_tfrecord_data(filename, data, header, dtypes)
+    else:
+        np.savetxt(
+            filename,
+            data,
+            delimiter=',',
+            header=','.join(header),
+            fmt=['%d' if i == np.int32 else '%f' for i in dtypes],
+            comments='')
 
 def process_mnist(X, y):
     X = X.reshape(X.shape[0], -1)
@@ -77,33 +92,42 @@ def make_data(args):
         os.makedirs('data/local_test')
 
     write_data(
-        'data/leader_train.csv', x_train, y_train,
+        args.output_type,
+        'data/leader_train.%s'%args.output_type, x_train, y_train,
         'leader', args.verify_example_ids)
     write_data(
-        'data/follower_train.csv', x_train, y_train,
+        args.output_type,
+        'data/follower_train.%s'%args.output_type, x_train, y_train,
         'follower', args.verify_example_ids)
     write_data(
-        'data/local_train.csv', x_train, y_train,
+        args.output_type,
+        'data/local_train.%s'%args.output_type, x_train, y_train,
         'local', False)
 
     write_data(
-        'data/leader_test/part-0001.csv', x_test, y_test,
+        args.output_type,
+        'data/leader_test/part-0001.%s'%args.output_type, x_test, y_test,
         'leader', args.verify_example_ids)
     write_data(
-        'data/follower_test/part-0001.csv', x_test, y_test,
+        args.output_type,
+        'data/follower_test/part-0001.%s'%args.output_type, x_test, y_test,
         'follower', args.verify_example_ids)
     write_data(
-        'data/local_test/part-0001.csv', x_test, y_test,
+        args.output_type,
+        'data/local_test/part-0001.%s'%args.output_type, x_test, y_test,
         'local', False)
 
     write_data(
-        'data/leader_test/part-0002.csv', x_test, y_test,
+        args.output_type,
+        'data/leader_test/part-0002.%s'%args.output_type, x_test, y_test,
         'leader', args.verify_example_ids)
     write_data(
-        'data/follower_test/part-0002.csv', x_test, y_test,
+        args.output_type,
+        'data/follower_test/part-0002.%s'%args.output_type, x_test, y_test,
         'follower', args.verify_example_ids)
     write_data(
-        'data/local_test/part-0002.csv', x_test, y_test,
+        args.output_type,
+        'data/local_test/part-0002.%s'%args.output_type, x_test, y_test,
         'local', False)
 
 
@@ -118,4 +142,6 @@ def make_data(args):
                              'must match between leader and follower')
     parser.add_argument('--dataset', type=str, default='mnist',
                         help='whether to use mnist or iris dataset')
+    parser.add_argument('--output-type', type=str, default='csv',
+                        help='Output csv or tfrecord')
     make_data(parser.parse_args())
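
Each record produced by write_tfrecord_data above is a tf.train.Example carrying exactly one scalar per feature key: int64 for the quantized np.int32 columns, float for everything else. A minimal inspection sketch, assuming the script has been run with --output-type=tfrecord and using the same compat.v1 API as the rest of this commit:

    import tensorflow.compat.v1 as tf

    # Walk the serialized records and print the first example's features.
    # tf_record_iterator matches the reader added to trainer.py below.
    for record in tf.io.tf_record_iterator('data/leader_train.tfrecord'):
        example = tf.train.Example()
        example.ParseFromString(record)
        for key, feature in example.features.feature.items():
            kind = feature.WhichOneof('kind')  # 'int64_list' or 'float_list' here
            print(key, kind, list(getattr(feature, kind).value))
        break  # first record only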
example/tree_model/test.sh (15 additions, 6 deletions)
@@ -6,15 +6,16 @@ cd "$( dirname "${BASH_SOURCE[0]}" )"
 
 rm -rf exp data
 
-python make_data.py --verify-example-ids=1 --dataset=iris
+python make_data.py --verify-example-ids=1 --dataset=iris --output-type=tfrecord
 
 python -m fedlearner.model.tree.trainer follower \
     --verbosity=1 \
     --local-addr=localhost:50052 \
     --peer-addr=localhost:50051 \
     --verify-example-ids=true \
-    --data-path=data/follower_train.csv \
-    --validation-data-path=data/follower_test/part-0001.csv \
+    --file-type=tfrecord \
+    --data-path=data/follower_train.tfrecord \
+    --validation-data-path=data/follower_test/part-0001.tfrecord \
     --checkpoint-path=exp/follower_checkpoints \
     --cat-fields=f00001 \
     --output-path=exp/follower_train_output.output &
@@ -24,8 +25,9 @@ python -m fedlearner.model.tree.trainer leader \
     --local-addr=localhost:50051 \
     --peer-addr=localhost:50052 \
     --verify-example-ids=true \
-    --data-path=data/leader_train.csv \
-    --validation-data-path=data/leader_test/part-0001.csv \
+    --file-type=tfrecord \
+    --data-path=data/leader_train.tfrecord \
+    --validation-data-path=data/leader_test/part-0001.tfrecord \
     --checkpoint-path=exp/leader_checkpoints \
     --cat-fields=f00001 \
     --output-path=exp/leader_train_output.output
@@ -38,6 +40,8 @@ python -m fedlearner.model.tree.trainer leader \
     --peer-addr=localhost:50052 \
     --mode=test \
     --verify-example-ids=true \
+    --file-type=tfrecord \
+    --file-ext=.tfrecord \
     --data-path=data/leader_test/ \
     --cat-fields=f00001 \
     --load-model-path=exp/leader_checkpoints/checkpoint-0004.proto \
@@ -49,6 +53,8 @@ python -m fedlearner.model.tree.trainer follower \
     --peer-addr=localhost:50051 \
     --mode=test \
     --verify-example-ids=true \
+    --file-type=tfrecord \
+    --file-ext=.tfrecord \
     --data-path=data/follower_test/ \
     --cat-fields=f00001 \
     --load-model-path=exp/follower_checkpoints/checkpoint-0004.proto \
@@ -59,12 +65,13 @@ wait
 
 rm -rf exp data
 
-python make_data.py --dataset=iris --verify-example-ids=1
+python make_data.py --dataset=iris --verify-example-ids=1 --output-type=csv
 
 python -m fedlearner.model.tree.trainer follower \
     --verbosity=1 \
     --local-addr=localhost:50052 \
     --peer-addr=localhost:50051 \
+    --file-type=csv \
     --data-path=data/follower_train.csv \
     --cat-fields=f00001 \
     --checkpoint-path=exp/follower_checkpoints \
@@ -74,6 +81,7 @@ python -m fedlearner.model.tree.trainer leader \
     --verbosity=1 \
     --local-addr=localhost:50051 \
     --peer-addr=localhost:50052 \
+    --file-type=csv \
     --data-path=data/leader_train.csv \
     --ignore-fields=f00000,f00001 \
     --checkpoint-path=exp/leader_checkpoints \
@@ -86,6 +94,7 @@ python -m fedlearner.model.tree.trainer follower \
     --local-addr=localhost:50052 \
     --peer-addr=localhost:50051 \
     --mode=test \
+    --file-type=csv \
     --data-path=data/follower_test/ \
     --cat-fields=f00001 \
     --load-model-path=exp/follower_checkpoints/checkpoint-0004.proto \
fedlearner/model/tree/trainer.py (70 additions, 40 deletions)
@@ -20,6 +20,7 @@
 import logging
 import argparse
 import traceback
+import itertools
 import numpy as np
 
 import tensorflow.compat.v1 as tf
@@ -60,6 +61,8 @@ def create_argument_parser():
                         help='Run prediction without data.')
     parser.add_argument('--file-ext', type=str, default='.csv',
                         help='File extension to use')
+    parser.add_argument('--file-type', type=str, default='csv',
+                        help='input file type: csv or tfrecord')
     parser.add_argument('--load-model-path',
                         type=str,
                         default=None,
@@ -130,26 +133,57 @@ def create_argument_parser():
 
     return parser
 
+
+def parse_tfrecord(record):
+    example = tf.train.Example()
+    example.ParseFromString(record)
+
+    parsed = {}
+    for key, value in example.features.feature.items():
+        kind = value.WhichOneof('kind')
+        if kind == 'float_list':
+            assert len(value.float_list.value) == 1, "Invalid tfrecord format"
+            parsed[key] = value.float_list.value[0]
+        elif kind == 'int64_list':
+            assert len(value.int64_list.value) == 1, "Invalid tfrecord format"
+            parsed[key] = value.int64_list.value[0]
+        elif kind == 'bytes_list':
+            assert len(value.bytes_list.value) == 1, "Invalid tfrecord format"
+            parsed[key] = value.bytes_list.value[0]
+        else:
+            raise ValueError("Invalid tfrecord format")
+
+    return parsed
+
+
 def extract_field(field_names, field_name, required):
     if field_name in field_names:
-        return field_names.index(field_name), []
+        return []
 
     assert not required, \
-        "Data must contain %s field"%field_name
-    return None, None
+        "Field %s is required but missing in data"%field_name
+    return None
 
-def read_csv_data(filename, require_example_ids, require_labels,
-                  ignore_fields, cat_fields):
-    logging.debug('Reading csv file from %s', filename)
-    fin = tf.io.gfile.GFile(filename, 'r')
-    reader = csv.reader(fin)
-    field_names = next(reader)
+def read_data(file_type, filename, require_example_ids,
+              require_labels, ignore_fields, cat_fields):
+    logging.debug('Reading data file from %s', filename)
+
+    if file_type == 'tfrecord':
+        reader = tf.io.tf_record_iterator(filename)
+        reader, tmp_reader = itertools.tee(reader)
+        first_line = parse_tfrecord(next(tmp_reader))
+        field_names = first_line.keys()
+    else:
+        fin = tf.io.gfile.GFile(filename, 'r')
+        reader = csv.DictReader(fin)
+        field_names = reader.fieldnames
 
-    example_id_idx, example_ids = extract_field(
+    example_ids = extract_field(
         field_names, 'example_id', require_example_ids)
-    raw_id_idx, raw_ids = extract_field(
+    raw_ids = extract_field(
         field_names, 'raw_id', False)
-    label_idx, labels = extract_field(
+    labels = extract_field(
         field_names, 'label', require_labels)
 
     ignore_fields = set(filter(bool, ignore_fields.strip().split(',')))
@@ -158,52 +192,48 @@ def read_csv_data(filename, require_example_ids, require_labels,
     for name in cat_fields:
         assert name in field_names, "cat_field %s missing"%name
 
-    cont_columns = [
-        (i, name) for i, name in enumerate(field_names) \
-        if name not in ignore_fields and name not in cat_fields]
+    cont_columns = list(filter(
+        lambda x: x not in ignore_fields and x not in cat_fields, field_names))
     cont_columns.sort(key=lambda x: x[1])
-    cat_columns = [
-        (i, name) for i, name in enumerate(field_names)
-        if name in cat_fields]
+    cat_columns = list(filter(
+        lambda x: x in cat_fields, field_names))
     cat_columns.sort(key=lambda x: x[1])
 
     features = []
     cat_features = []
     for line in reader:
-        if example_id_idx is not None:
-            example_ids.append(line[example_id_idx])
-        if raw_id_idx is not None:
-            raw_ids.append(line[raw_id_idx])
-        if label_idx is not None:
-            labels.append(float(line[label_idx]))
-        features.append([float(line[i]) for i, _ in cont_columns])
-        cat_features.append([int(line[i]) for i, _ in cat_columns])
-
-    fin.close()
-
-    feature_names = [name for _, name in cont_columns]
-    cat_feature_names = [name for _, name in cat_columns]
+        if file_type == 'tfrecord':
+            line = parse_tfrecord(line)
+        if example_ids is not None:
+            example_ids.append(str(line['example_id']))
+        if raw_ids is not None:
+            raw_ids.append(str(line['raw_id']))
+        if labels is not None:
+            labels.append(float(line['label']))
+        features.append([float(line[i]) for i in cont_columns])
+        cat_features.append([int(line[i]) for i in cat_columns])
 
     features = np.array(features, dtype=np.float)
     cat_features = np.array(cat_features, dtype=np.int32)
     if labels is not None:
         labels = np.asarray(labels, dtype=np.float)
 
-    return features, cat_features, feature_names, cat_feature_names, \
+    return features, cat_features, cont_columns, cat_columns, \
         labels, example_ids, raw_ids
 
 
 def train(args, booster):
-    X, cat_X, X_names, cat_X_names, y, example_ids, _ = read_csv_data(
-        args.data_path, args.verify_example_ids,
+    X, cat_X, X_names, cat_X_names, y, example_ids, _ = read_data(
+        args.file_type, args.data_path, args.verify_example_ids,
         args.role != 'follower', args.ignore_fields, args.cat_fields)
 
     if args.validation_data_path:
         val_X, val_cat_X, val_X_names, val_cat_X_names, val_y, \
             val_example_ids, _ = \
-            read_csv_data(
-                args.validation_data_path, args.verify_example_ids,
-                args.role != 'follower', args.ignore_fields,
-                args.cat_fields)
+            read_data(
+                args.file_type, args.validation_data_path,
+                args.verify_example_ids, args.role != 'follower',
+                args.ignore_fields, args.cat_fields)
         assert X_names == val_X_names, \
             "Train data and validation data must have same features"
         assert cat_X_names == val_cat_X_names, \
@@ -258,8 +288,8 @@ def test_one_file(args, bridge, booster, data_file, output_file):
         X = cat_X = X_names = cat_X_names = y = example_ids = raw_ids = None
     else:
         X, cat_X, X_names, cat_X_names, y, example_ids, raw_ids = \
-            read_csv_data(
-                data_file, args.verify_example_ids,
+            read_data(
+                args.file_type, data_file, args.verify_example_ids,
                 False, args.ignore_fields, args.cat_fields)
 
     pred = booster.batch_predict(
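
The parse_tfrecord helper added above only accepts Examples that carry exactly one value per feature key, which is what write_tfrecord_data emits. A round-trip sketch of that contract, assuming fedlearner is importable; the field names here are illustrative:

    import tensorflow.compat.v1 as tf
    from fedlearner.model.tree.trainer import parse_tfrecord

    # Build an Example the same way write_tfrecord_data does: one scalar per key.
    example = tf.train.Example()
    example.features.feature['example_id'].int64_list.value.append(7)
    example.features.feature['f00002'].float_list.value.append(0.5)

    parsed = parse_tfrecord(example.SerializeToString())
    assert parsed['example_id'] == 7
    assert abs(parsed['f00002'] - 0.5) < 1e-6  # tolerate float32 rounding

    # A second value under the same key violates the one-scalar contract:
    example.features.feature['f00002'].float_list.value.append(1.0)
    try:
        parse_tfrecord(example.SerializeToString())
    except AssertionError:
        print('multi-value feature rejected, as expected')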
