Skip to content

Commit

Permalink
Don't copy recordbatches in memory during a table deepcopy (huggingface#2291)
Browse files Browse the repository at this point in the history

* don't copy the recordbatches in memory during a table deepcopy

* update tests
  • Loading branch information
lhoestq authored Apr 29, 2021
1 parent 02a27b7 commit 0aeb7f5
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class IndexedTableMixin:
def __init__(self, table: pa.Table):
    """Build a row-lookup index over ``table``'s record batches.

    Stores the schema, the list of record batches, and the cumulative row
    offsets used to map a global row index to a (batch, row-in-batch) pair.
    """
    self._schema = table.schema
    self._batches = table.to_batches()
    # _offsets[i] is the global index of the first row of batch i; the last
    # entry is the total row count (cumsum over batch lengths, 0-prefixed).
    # NOTE: the diff artifact duplicated this assignment (old + new line);
    # only the annotated version is kept.
    self._offsets: np.ndarray = np.cumsum([0] + [len(b) for b in self._batches])

def fast_gather(self, indices: Union[List[int], np.ndarray]) -> pa.Table:
"""
Expand Down Expand Up @@ -158,6 +158,8 @@ def __deepcopy__(self, memo: dict):
# moreover calling deepcopy on a pyarrow table seems to make pa.total_allocated_bytes() decrease for some reason
# by adding it to the memo, self.table won't be copied
memo[id(self.table)] = self.table
# same for the recordbatches used by the index
memo[id(self._batches)] = list(self._batches)
return _deepcopy(self, memo)

def __getstate__(self):
Expand Down
111 changes: 111 additions & 0 deletions tests/test_table.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import pickle
from typing import List, Union

Expand Down Expand Up @@ -53,6 +54,20 @@ def mixed_in_memory_and_memory_mapped_blocks(in_memory_blocks, memory_mapped_blo
return in_memory_blocks[:1] + memory_mapped_blocks[1:]


def assert_deepcopy_without_bringing_data_in_memory(table: MemoryMappedTable):
    """Deep-copy ``table`` and verify that Arrow's allocated memory stays flat."""
    with assert_arrow_memory_doesnt_increase():
        duplicate = copy.deepcopy(table)
        assert isinstance(duplicate, MemoryMappedTable)
        assert duplicate.table == table.table


def assert_deepcopy_does_bring_data_in_memory(table: MemoryMappedTable):
    """Deep-copy ``table`` and verify that Arrow's allocated memory does grow."""
    with assert_arrow_memory_increases():
        duplicate = copy.deepcopy(table)
        assert isinstance(duplicate, MemoryMappedTable)
        assert duplicate.table == table.table


def assert_pickle_without_bringing_data_in_memory(table: MemoryMappedTable):
with assert_arrow_memory_doesnt_increase():
pickled_table = pickle.dumps(table)
Expand All @@ -69,6 +84,12 @@ def assert_pickle_does_bring_data_in_memory(table: MemoryMappedTable):
assert unpickled_table.table == table.table


def assert_index_attributes_equal(table: Table, other: Table):
    """Check that two tables carry the same lookup-index state (_batches, _offsets, _schema)."""
    assert other._batches == table._batches
    np.testing.assert_array_equal(other._offsets, table._offsets)
    assert other._schema == table._schema


def test_inject_arrow_table_documentation(in_memory_pa_table):
method = pa.Table.slice

Expand Down Expand Up @@ -227,6 +248,24 @@ def test_in_memory_table_from_batches(in_memory_pa_table):
assert isinstance(table, InMemoryTable)


def test_in_memory_table_deepcopy(in_memory_pa_table):
    """Deep-copying an InMemoryTable should share, not duplicate, its arrow data."""
    table = InMemoryTable(in_memory_pa_table)
    replica = copy.deepcopy(table)
    assert table.table == replica.table
    assert_index_attributes_equal(table, replica)
    # pyarrow objects are immutable, so deepcopy should hand back the same instances
    assert table.table is replica.table
    for batch, batch_copy in zip(table._batches, replica._batches):
        assert batch is batch_copy


def test_in_memory_table_pickle(in_memory_pa_table):
    """An InMemoryTable should survive a pickle round-trip with its index intact."""
    table = InMemoryTable(in_memory_pa_table)
    roundtripped = pickle.loads(pickle.dumps(table))
    assert roundtripped.table == table.table
    assert_index_attributes_equal(table, roundtripped)


@slow
def test_in_memory_table_pickle_big_table():
big_table_4GB = InMemoryTable.from_pydict({"col": [0] * ((4 * 8 << 30) // 64)})
Expand Down Expand Up @@ -325,6 +364,7 @@ def test_memory_mapped_table_init(arrow_file, in_memory_pa_table):
table = MemoryMappedTable(_memory_mapped_arrow_table_from_file(arrow_file), arrow_file)
assert table.table == in_memory_pa_table
assert isinstance(table, MemoryMappedTable)
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -333,6 +373,7 @@ def test_memory_mapped_table_from_file(arrow_file, in_memory_pa_table):
table = MemoryMappedTable.from_file(arrow_file)
assert table.table == in_memory_pa_table
assert isinstance(table, MemoryMappedTable)
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -344,12 +385,34 @@ def test_memory_mapped_table_from_file_with_replay(arrow_file, in_memory_pa_tabl
for method, args, kwargs in replays:
in_memory_pa_table = getattr(in_memory_pa_table, method)(*args, **kwargs)
assert table.table == in_memory_pa_table
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


def test_memory_mapped_table_deepcopy(arrow_file):
    """Deep-copying a MemoryMappedTable should share its immutable arrow objects."""
    table = MemoryMappedTable.from_file(arrow_file)
    replica = copy.deepcopy(table)
    assert table.table == replica.table
    assert table.path == replica.path
    assert_index_attributes_equal(table, replica)
    # pyarrow objects are immutable, so deepcopy should hand back the same instances
    assert table.table is replica.table
    for batch, batch_copy in zip(table._batches, replica._batches):
        assert batch is batch_copy


def test_memory_mapped_table_pickle(arrow_file):
    """A MemoryMappedTable should survive a pickle round-trip, path and index included."""
    table = MemoryMappedTable.from_file(arrow_file)
    roundtripped = pickle.loads(pickle.dumps(table))
    assert roundtripped.table == table.table
    assert roundtripped.path == table.path
    assert_index_attributes_equal(table, roundtripped)


def test_memory_mapped_table_pickle_doesnt_fill_memory(arrow_file):
    """Opening and serializing a memory-mapped table must stay zero-copy."""
    with assert_arrow_memory_doesnt_increase():
        table = MemoryMappedTable.from_file(arrow_file)
    assert_deepcopy_without_bringing_data_in_memory(table)
    assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -359,6 +422,7 @@ def test_memory_mapped_table_pickle_applies_replay(arrow_file):
table = MemoryMappedTable.from_file(arrow_file, replays=replays)
assert isinstance(table, MemoryMappedTable)
assert table.replays == replays
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -367,6 +431,7 @@ def test_memory_mapped_table_slice(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.slice(1, 2)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("slice", (1, 2), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -376,6 +441,7 @@ def test_memory_mapped_table_filter(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.filter(mask)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("filter", (mask,), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
# filter DOES increase memory
# assert_pickle_without_bringing_data_in_memory(table)
assert_pickle_does_bring_data_in_memory(table)
Expand All @@ -386,6 +452,7 @@ def test_memory_mapped_table_flatten(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.flatten()
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("flatten", tuple(), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -394,6 +461,7 @@ def test_memory_mapped_table_combine_chunks(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.combine_chunks()
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("combine_chunks", tuple(), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -409,6 +477,7 @@ def test_memory_mapped_table_cast(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.cast(schema)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("cast", (schema,), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
# cast DOES increase memory when converting integers precision for example
# assert_pickle_without_bringing_data_in_memory(table)
assert_pickle_does_bring_data_in_memory(table)
Expand All @@ -422,6 +491,7 @@ def test_memory_mapped_table_add_column(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.add_column(i, field_, column)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("add_column", (i, field_, column), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -432,6 +502,7 @@ def test_memory_mapped_table_append_column(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.append_column(field_, column)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("append_column", (field_, column), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -440,6 +511,7 @@ def test_memory_mapped_table_remove_column(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.remove_column(0)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("remove_column", (0,), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -451,6 +523,7 @@ def test_memory_mapped_table_set_column(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.set_column(i, field_, column)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("set_column", (i, field_, column), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -461,6 +534,7 @@ def test_memory_mapped_table_rename_columns(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.rename_columns(names)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("rename_columns", (names,), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand All @@ -470,6 +544,7 @@ def test_memory_mapped_table_drop(arrow_file, in_memory_pa_table):
assert table.table == in_memory_pa_table.drop(names)
assert isinstance(table, MemoryMappedTable)
assert table.replays == [("drop", (names,), {})]
assert_deepcopy_without_bringing_data_in_memory(table)
assert_pickle_without_bringing_data_in_memory(table)


Expand Down Expand Up @@ -556,6 +631,42 @@ def test_concatenation_table_from_tables(axis, in_memory_pa_table, arrow_file):
assert isinstance(table.blocks[1][0] if axis == 0 else table.blocks[0][1], MemoryMappedTable)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_deepcopy(
    blocks_type, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """Deep-copying a ConcatenationTable should share its arrow data for every block layout."""
    blocks_by_type = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }
    table = ConcatenationTable.from_blocks(blocks_by_type[blocks_type])
    replica = copy.deepcopy(table)
    assert table.table == replica.table
    assert table.blocks == replica.blocks
    assert_index_attributes_equal(table, replica)
    # pyarrow objects are immutable, so deepcopy should hand back the same instances
    assert table.table is replica.table
    for batch, batch_copy in zip(table._batches, replica._batches):
        assert batch is batch_copy


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_pickle(
    blocks_type, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
):
    """A ConcatenationTable should survive a pickle round-trip for every block layout."""
    blocks_by_type = {
        "in_memory": in_memory_blocks,
        "memory_mapped": memory_mapped_blocks,
        "mixed": mixed_in_memory_and_memory_mapped_blocks,
    }
    table = ConcatenationTable.from_blocks(blocks_by_type[blocks_type])
    roundtripped = pickle.loads(pickle.dumps(table))
    assert roundtripped.table == table.table
    assert roundtripped.blocks == table.blocks
    assert_index_attributes_equal(table, roundtripped)


@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
def test_concatenation_table_slice(
blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
Expand Down

0 comments on commit 0aeb7f5

Please sign in to comment.