
Commit

Do not use the csv dict reader/writer since the schema costs too much memory (bytedance#311)

* make the heap memory stats a singleton

* make the csv reader/writer not use the csv_dict_reader/csv_dict_writer

* make the lint pass

* make the unittest pass

* add the traceback print_stack before os._exit

* make the tf-to-csv conversion maintain the sequence of fields

Co-authored-by: fangchenliaohui <[email protected]>
fclh1991 and fangchenliaohui authored Sep 2, 2020
1 parent 70334cc commit 4b4539a
Showing 22 changed files with 219 additions and 111 deletions.
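
The commit's rationale is memory: the old path built an OrderedDict per record, so every row carried its own schema container, while the new path keeps one shared key list plus a flat value list per row. A rough, CPython-specific illustration of the difference (sizes vary by interpreter build and are not taken from the repository):

import sys

header = ['raw_id', 'feat_0', 'feat_1', 'feat_2']   # shared schema
dict_row = {'raw_id': '1', 'feat_0': 'a', 'feat_1': 'b', 'feat_2': 'c'}
list_row = ['1', 'a', 'b', 'c']

print(sys.getsizeof(header))     # schema cost paid once for the whole file
print(sys.getsizeof(dict_row))   # every record pays for its own dict/hash table
print(sys.getsizeof(list_row))   # every record pays only for a flat value list
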
12 changes: 5 additions & 7 deletions fedlearner/data_join/cmd/generate_csv_raw_data.py
@@ -17,7 +17,6 @@
import argparse
import logging
import os
from collections import OrderedDict
from cityhash import CityHash32 # pylint: disable=no-name-in-module

import tensorflow_io # pylint: disable=unused-import
@@ -39,16 +38,15 @@ def generate_input_csv(base_dir, start_id, end_id, partition_num):
csv_writers = [SortRunMergerWriter(base_dir, 0,
partition_id, writer_options)
for partition_id in range(partition_num)]
field_keys = ['raw_id', 'feat_0', 'feat_1', 'feat_2']
for idx in range(start_id, end_id):
if idx % 262144 == 0:
logging.info("Process at index %d", idx)
partition_id = CityHash32(str(idx)) % partition_num
raw = OrderedDict()
raw['raw_id'] = str(idx)
raw['feat_0'] = str((partition_id << 30) + 0) + str(idx)
raw['feat_1'] = str((partition_id << 30) + 1) + str(idx)
raw['feat_2'] = str((partition_id << 30) + 2) + str(idx)
csv_writers[partition_id].append(CsvItem(raw))
field_vals = [str(idx), str((partition_id << 30) + 0) + str(idx),
str((partition_id << 30) + 1) + str(idx),
str((partition_id << 30) + 2) + str(idx)]
csv_writers[partition_id].append(CsvItem(field_keys, field_vals))
for partition_id, csv_writer in enumerate(csv_writers):
fpaths = csv_writer.finish()
logging.info("partition %d dump %d files", partition_id, len(fpaths))
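
A minimal sketch of the new row construction in generate_csv_raw_data.py, with a namedtuple standing in for the repository's CsvItem and the SortRunMergerWriter pool omitted (both stand-ins are assumptions for illustration):

from collections import namedtuple

# Illustrative stand-in: the real CsvItem now receives the key list and the
# value list separately instead of a per-row OrderedDict.
CsvItem = namedtuple('CsvItem', ['field_keys', 'field_vals'])

FIELD_KEYS = ['raw_id', 'feat_0', 'feat_1', 'feat_2']  # shared by every row

def make_item(idx, partition_id):
    # Values are built positionally so they stay aligned with FIELD_KEYS.
    field_vals = [str(idx),
                  str((partition_id << 30) + 0) + str(idx),
                  str((partition_id << 30) + 1) + str(idx),
                  str((partition_id << 30) + 2) + str(idx)]
    return CsvItem(FIELD_KEYS, field_vals)

print(make_item(7, 3))
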
2 changes: 2 additions & 0 deletions fedlearner/data_join/cmd/rsa_psi_preprocessor_cli.py
@@ -17,6 +17,7 @@
import argparse
import logging
import os
import traceback

import tensorflow_io # pylint: disable=unused-import
from tensorflow.compat.v1 import gfile
@@ -123,6 +124,7 @@
offload_processor_number = int(os.environ.get('CPU_LIMIT', '2')) - 1
if offload_processor_number < 1:
logging.fatal("we should at least retain 1 cpu for compute task")
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
preprocessor_options = dj_pb.RsaPsiPreProcessorOptions(
preprocessor_name=args.preprocessor_name,
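
The same log/print_stack/os._exit sequence is added in front of every hard exit touched by this commit; a standalone sketch of the pattern (the fatal_exit helper name is invented here, the repository inlines the three calls):

import logging
import os
import traceback

def fatal_exit(msg, *args):
    # Log the reason, dump the current call stack, then exit without cleanup,
    # mirroring the three-line pattern added throughout this commit.
    logging.fatal(msg, *args)
    traceback.print_stack()
    os._exit(-1)  # pylint: disable=protected-access

offload_processor_number = int(os.environ.get('CPU_LIMIT', '2')) - 1
if offload_processor_number < 1:
    fatal_exit("we should at least retain 1 cpu for compute task")
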
65 changes: 46 additions & 19 deletions fedlearner/data_join/common.py
@@ -20,7 +20,6 @@
import threading
import time
from contextlib import contextmanager
from collections import OrderedDict

from guppy import hpy

@@ -143,10 +142,12 @@ def raw_data_pub_etcd_key(pub_base_dir, partition_id, process_index):
'{:08}{}'.format(process_index, RawDataPubSuffix))

_valid_basic_feature_type = (int, str, float)
def convert_dict_to_tf_example(src_dict):
assert isinstance(src_dict, dict)
def convert_csv_record_to_tf_example(field_keys, field_vals):
assert isinstance(field_keys, list) and \
isinstance(field_vals, list) and \
len(field_keys) == len(field_vals)
tf_feature = {}
for key, feature in src_dict.items():
for key, feature in zip(field_keys, field_vals):
if not isinstance(key, str):
raise RuntimeError('the key {}({}) of dict must a '\
'string'.format(key, type(key)))
@@ -188,11 +189,13 @@ def convert_dict_to_tf_example(src_dict):
float_list=tf.train.FloatList(value=value))
return tf.train.Example(features=tf.train.Features(feature=tf_feature))

def convert_tf_example_to_dict(src_tf_example):
def convert_tf_example_to_csv_record(src_tf_example):
assert isinstance(src_tf_example, tf.train.Example)
dst_dict = OrderedDict()
field_keys, field_vals = [], []
tf_feature = src_tf_example.features.feature
for key, feat in tf_feature.items():
sorted_keys = sorted(tf_feature.keys())
for key in sorted_keys:
feat = tf_feature[key]
csv_val = None
if feat.HasField('int64_list'):
csv_val = [item for item in feat.int64_list.value] # pylint: disable=unnecessary-comprehension
@@ -203,8 +206,16 @@ def convert_tf_example_to_dict(src_tf_example):
else:
assert False, "feat type must in int64, byte, float"
assert isinstance(csv_val, list)
dst_dict[key] = csv_val[0] if len(csv_val) == 1 else csv_val
return dst_dict
insert_idx = len(field_keys)
if key == 'example_id':
insert_idx = 0
elif key == 'raw_id':
insert_idx = 1 if (len(field_keys) > 0 and \
field_keys[0] == 'example_id') else 0
rval = csv_val[0] if len(csv_val) == 1 else csv_val
field_keys.insert(insert_idx, key)
field_vals.insert(insert_idx, rval)
return field_keys, field_vals

def int2bytes(digit, byte_len, byteorder='little'):
return int(digit).to_bytes(byte_len, byteorder)
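
convert_tf_example_to_csv_record above no longer depends on dict iteration order: the feature keys are sorted, then example_id and raw_id are moved to the front, so every row of a dump comes out with the same column sequence. The ordering logic in isolation (plain dict input, without the tf.train.Example decoding):

def order_csv_fields(record):
    # record: feature name -> already-decoded csv value (illustrative input).
    field_keys, field_vals = [], []
    for key in sorted(record.keys()):
        insert_idx = len(field_keys)
        if key == 'example_id':
            insert_idx = 0
        elif key == 'raw_id':
            # raw_id sits right after example_id when present, otherwise first.
            insert_idx = 1 if field_keys[:1] == ['example_id'] else 0
        field_keys.insert(insert_idx, key)
        field_vals.insert(insert_idx, record[key])
    return field_keys, field_vals

print(order_csv_fields({'feat_0': 1, 'raw_id': 'r1', 'example_id': 'e1'}))
# (['example_id', 'raw_id', 'feat_0'], ['e1', 'r1', 1])
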
@@ -238,7 +249,18 @@ def data_source_data_block_dir(data_source):
def data_source_example_dumped_dir(data_source):
return os.path.join(data_source.output_base_dir, 'example_dump')

class _MemUsageProxy(object):

class Singleton(type):
_instances = {}
_lck = threading.Lock()
def __call__(cls, *args, **kwargs):
with cls._lck:
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args,
**kwargs)
return cls._instances[cls]

class _MemUsageProxy(object, metaclass=Singleton):
def __init__(self):
self._lock = threading.Lock()
self._mem_limit = int(os.environ.get('MEM_LIMIT', '17179869184'))
@@ -262,14 +284,15 @@ def get_heap_mem_usage(self):

def _update_rss_mem_usage(self):
with self._lock:
if time.time() - self._rss_updated_tm >= 0.5:
if time.time() - self._rss_updated_tm >= 0.25:
self._rss_mem_usage = psutil.Process().memory_info().rss
self._rss_updated_tm = time.time()
return self._rss_mem_usage

_mem_usage_proxy = _MemUsageProxy()
def _get_mem_usage_proxy():
return _MemUsageProxy()

class HeapMemStats(object):
class _HeapMemStats(object, metaclass=Singleton):
class StatsRecord(object):
def __init__(self, potential_mem_incr, stats_expiration_time):
self._lock = threading.Lock()
@@ -287,7 +310,8 @@ def stats_expiration(self):
def update_stats(self):
with self._lock:
if self.stats_expiration():
self._heap_mem_usage = _mem_usage_proxy.get_heap_mem_usage()
self._heap_mem_usage = \
_get_mem_usage_proxy().get_heap_mem_usage()
self._stats_ts = time.time()

def get_heap_mem_usage(self):
@@ -311,17 +335,17 @@ def CheckOomRisk(self, stats_key,
if inner_key not in self._stats_map:
inner_key = self._gen_inner_stats_key(stats_key)
self._stats_map[inner_key] = \
HeapMemStats.StatsRecord(potential_mem_incr,
self._stats_expiration_time)
_HeapMemStats.StatsRecord(potential_mem_incr,
self._stats_expiration_time)
sr = self._stats_map[inner_key]
if not sr.stats_expiration():
return _mem_usage_proxy.check_heap_mem_water_level(
return _get_mem_usage_proxy().check_heap_mem_water_level(
sr.get_heap_mem_usage(),
water_level_percent
)
assert sr is not None
sr.update_stats()
return _mem_usage_proxy.check_heap_mem_water_level(
return _get_mem_usage_proxy().check_heap_mem_water_level(
sr.get_heap_mem_usage(), water_level_percent
)

@@ -331,7 +355,7 @@ def _gen_inner_stats_key(self, stats_key):
def _need_heap_stats(self, stats_key):
with self._lock:
if self._stats_granular <= 0 and \
_mem_usage_proxy.check_rss_mem_water_level(0.5):
_get_mem_usage_proxy().check_rss_mem_water_level(0.5):
self._stats_granular = stats_key // 16
self._stats_start_key = stats_key // 2
if self._stats_granular <= 0:
Expand All @@ -340,3 +364,6 @@ def _need_heap_stats(self, stats_key):
self._stats_granular)
return self._stats_granular > 0 and \
stats_key >= self._stats_start_key

def get_heap_mem_stats(stats_expiration_time):
return _HeapMemStats(stats_expiration_time)
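
The Singleton metaclass added above is what lets _get_mem_usage_proxy() and get_heap_mem_stats() hand out one shared instance per process instead of the old module-level _mem_usage_proxy global. A standalone sketch of the mechanism (the DemoStats class is made up):

import threading

class Singleton(type):
    # Same shape as the metaclass in common.py: the first construction of a
    # class is cached, later constructions return the cached instance.
    _instances = {}
    _lck = threading.Lock()
    def __call__(cls, *args, **kwargs):
        with cls._lck:
            if cls not in cls._instances:
                cls._instances[cls] = super(Singleton, cls).__call__(*args,
                                                                     **kwargs)
        return cls._instances[cls]

class DemoStats(metaclass=Singleton):
    def __init__(self, stats_expiration_time):
        self.stats_expiration_time = stats_expiration_time

a = DemoStats(5)
b = DemoStats(30)   # arguments of later calls are ignored
assert a is b and b.stats_expiration_time == 5
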
31 changes: 21 additions & 10 deletions fedlearner/data_join/csv_dict_writer.py
@@ -16,6 +16,9 @@

import csv
import io
import os
import logging
import traceback

import tensorflow_io # pylint: disable=unused-import
from tensorflow.compat.v1 import gfile
@@ -27,18 +30,26 @@ def __init__(self, fpath):
self._file_hanlde = gfile.Open(fpath, 'w+')
self._buffer_handle = io.StringIO()
self._csv_writer = None
self._csv_headers = None

def write(self, raw):
assert isinstance(raw, dict)
if len(raw) == 0:
return
def write(self, fields):
assert isinstance(fields, tuple)
field_keys, field_vals = fields[0], fields[1]
assert isinstance(field_keys, list) and \
isinstance(field_vals, list) and \
len(field_keys) == len(field_vals)
if self._csv_writer is None:
self._csv_writer = csv.DictWriter(
self._buffer_handle,
fieldnames=raw.keys()
)
self._csv_writer.writeheader()
self._csv_writer.writerow(raw)
self._csv_writer = csv.writer(self._buffer_handle)
self._csv_headers = field_keys
self._csv_writer.writerow(field_keys)
else:
assert self._csv_headers is not None
if self._csv_headers != field_keys:
logging.fatal("the schema of csv item is %s, mismatch with "\
"previous %s", self._csv_headers, field_keys)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
self._csv_writer.writerow(field_vals)
self._write_raw_num += 1
self._flush_buffer(False)

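
A condensed, standalone version of the new writer behaviour in csv_dict_writer.py: the header row is written once from the first item's key list, later items must carry the same keys, and only value lists are written per row (a ValueError here replaces the logging.fatal/os._exit used by the repository):

import csv
import io

class PlainCsvWriter:
    # Illustrative reimplementation: csv.writer plus a remembered header,
    # instead of csv.DictWriter with per-row dicts.
    def __init__(self, stream):
        self._csv_writer = csv.writer(stream)
        self._csv_headers = None

    def write(self, field_keys, field_vals):
        assert len(field_keys) == len(field_vals)
        if self._csv_headers is None:
            self._csv_headers = field_keys
            self._csv_writer.writerow(field_keys)      # header written once
        elif self._csv_headers != field_keys:
            raise ValueError("csv schema {} mismatches previous {}".format(
                field_keys, self._csv_headers))
        self._csv_writer.writerow(field_vals)

buf = io.StringIO()
writer = PlainCsvWriter(buf)
writer.write(['raw_id', 'feat_0'], ['1', 'a'])
writer.write(['raw_id', 'feat_0'], ['2', 'b'])
print(buf.getvalue())
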
11 changes: 6 additions & 5 deletions fedlearner/data_join/data_block_dumper.py
@@ -18,6 +18,7 @@
import logging
import os
import time
import traceback
from contextlib import contextmanager

from fedlearner.common import metrics
@@ -157,6 +158,7 @@ def _dump_data_block_by_meta(self, meta):
except StopIteration:
logging.fatal("raw data finished before when seek to %d",
meta.leader_start_index-1)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
match_index = 0
example_num = len(meta.example_ids)
@@ -170,11 +172,10 @@
if index >= meta.leader_end_index:
break
if match_index < example_num:
logging.fatal(
"Data lose corrupt! only match %d/%d example "
"for data block %s",
match_index, example_num, meta.block_id
)
logging.fatal("Data lose corrupt! only match %d/%d example "
"for data block %s",
match_index, example_num, meta.block_id)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
dumped_meta = data_block_builder.finish_data_block(True)
assert dumped_meta == meta, "the generated dumped meta shoud "\
3 changes: 3 additions & 0 deletions fedlearner/data_join/data_block_manager.py
@@ -18,6 +18,7 @@
import logging
import os
import copy
import traceback

import tensorflow.compat.v1 as tf
from google.protobuf import text_format
@@ -294,6 +295,7 @@ def _make_directory_if_nessary(self):
gfile.MakeDirs(data_block_dir)
if not gfile.IsDirectory(data_block_dir):
logging.fatal("%s should be directory", data_block_dir)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access

def _sync_data_block_meta(self, index):
@@ -306,6 +308,7 @@ def _sync_data_block_meta(self, index):
if meta is None:
logging.fatal("data block index as %d has dumped "\
"but vanish", index)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
self._data_block_meta_cache[index] = meta
return self._data_block_meta_cache[index]
2 changes: 2 additions & 0 deletions fedlearner/data_join/example_id_batch_fetcher.py
@@ -16,6 +16,7 @@

import logging
import os
import traceback

from fedlearner.common import data_join_service_pb2 as dj_pb
from fedlearner.data_join.item_batch_seq_processor import \
@@ -97,6 +98,7 @@ def _make_inner_generator(self, next_index):
logging.fatal("index of raw data visitor for partition "\
"%d is not consecutive, %d != %d",
self._partition_id, index, next_index)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
next_batch.append(item)
next_index += 1
15 changes: 10 additions & 5 deletions fedlearner/data_join/example_id_visitor.py
@@ -16,7 +16,7 @@

import logging
import os
from os import path
import traceback

from google.protobuf import text_format, empty_pb2

@@ -38,7 +38,7 @@ def encode_example_id_dumped_fname(process_index, start_index):
return '{:06}-{:08}{}'.format(process_index, start_index, DoneFileSuffix)

def decode_index_meta(fpath):
fname = path.basename(fpath)
fname = os.path.basename(fpath)
index_str = fname[:-len(DoneFileSuffix)]
try:
items = index_str.split('-')
@@ -48,6 +48,7 @@ def decode_index_meta(fpath):
except Exception as e: # pylint: disable=broad-except
logging.fatal("fname %s not satisfied with pattern process_index-"\
"start_index", fname)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
else:
return visitor.IndexMeta(process_index, start_index, fpath)
@@ -100,8 +101,8 @@ def update_dumped_example_id_anchor(self, index_meta, end_index):
process_index = index_meta.process_index
start_index = index_meta.start_index
fpath = index_meta.fpath
dirname = path.dirname(fpath)
fname = path.basename(fpath)
dirname = os.path.dirname(fpath)
fname = os.path.basename(fpath)
if not gfile.Exists(fpath):
raise ValueError("file {} is not existed".format(fpath))
if dirname != self._example_dumped_dir():
@@ -165,11 +166,12 @@ def _preload_example_id_meta(self):
if index != index_meta.process_index:
logging.fatal("%s has error process index. expected %d",
index_meta.fpath, index)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
return index_metas

def _decode_index_meta(self, fpath):
fname = path.basename(fpath)
fname = os.path.basename(fpath)
index_str = fname[:-len(DoneFileSuffix)]
try:
items = index_str.split('-')
@@ -179,6 +181,7 @@ def _decode_index_meta(self, fpath):
except Exception as e: # pylint: disable=broad-except
logging.fatal("fname %s not satisfied with pattern process_index-"\
"start_index", fname)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
else:
return visitor.IndexMeta(process_index, start_index, fpath)
@@ -195,6 +198,7 @@ def _new_index_meta(self, process_index, start_index):
if not gfile.Exists(fpath):
logging.fatal("%d has been dumpped however %s not "\
"in file system", start_index, fpath)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access
return visitor.IndexMeta(process_index, start_index, fpath)
return None
@@ -230,6 +234,7 @@ def _make_directory_if_nessary(self):
gfile.MakeDirs(example_dumped_dir)
if not gfile.IsDirectory(example_dumped_dir):
logging.fatal("%s should be directory", example_dumped_dir)
traceback.print_stack()
os._exit(-1) # pylint: disable=protected-access

class ExampleIdVisitor(visitor.Visitor):
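
example_id_visitor.py keeps the '{:06}-{:08}<suffix>' naming for dumped example-id files and now parses it with os.path instead of the bare path import. A tiny round-trip of that encoding (DoneFileSuffix is assumed to be '.done' for this sketch):

import os

DoneFileSuffix = '.done'  # assumed value for the sketch

def encode_example_id_dumped_fname(process_index, start_index):
    # Same format string as in example_id_visitor.py.
    return '{:06}-{:08}{}'.format(process_index, start_index, DoneFileSuffix)

def decode_index_meta(fpath):
    fname = os.path.basename(fpath)
    index_str = fname[:-len(DoneFileSuffix)]
    process_index, start_index = (int(part) for part in index_str.split('-'))
    return process_index, start_index

fname = encode_example_id_dumped_fname(3, 1024)
print(fname)                               # 000003-00001024.done
print(decode_index_meta('/tmp/' + fname))  # (3, 1024)
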
