From 0aeb7f51ba64585e6d8173e8a15ba9eda84cf402 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date: Thu, 29 Apr 2021 18:34:33 +0200
Subject: [PATCH] Don't copy recordbatches in memory during a table deepcopy
 (#2291)

* don't copy the recordbatches in memory during a table deepcopy

* update tests
---
 src/datasets/table.py |   4 +-
 tests/test_table.py   | 111 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/src/datasets/table.py b/src/datasets/table.py
index e9001a79f08..13f9d7046fb 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -95,7 +95,7 @@ class IndexedTableMixin:
     def __init__(self, table: pa.Table):
         self._schema = table.schema
         self._batches = table.to_batches()
-        self._offsets = np.cumsum([0] + [len(b) for b in self._batches])
+        self._offsets: np.ndarray = np.cumsum([0] + [len(b) for b in self._batches])
 
     def fast_gather(self, indices: Union[List[int], np.ndarray]) -> pa.Table:
         """
@@ -158,6 +158,8 @@ def __deepcopy__(self, memo: dict):
         # moreover calling deepcopy on a pyarrow table seems to make pa.total_allocated_bytes() decrease for some reason
         # by adding it to the memo, self.table won't be copied
         memo[id(self.table)] = self.table
+        # same for the recordbatches used by the index
+        memo[id(self._batches)] = list(self._batches)
         return _deepcopy(self, memo)
 
     def __getstate__(self):
diff --git a/tests/test_table.py b/tests/test_table.py
index c734a8c9612..78f5807b339 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -1,3 +1,4 @@
+import copy
 import pickle
 from typing import List, Union
 
@@ -53,6 +54,20 @@ def mixed_in_memory_and_memory_mapped_blocks(in_memory_blocks, memory_mapped_blo
     return in_memory_blocks[:1] + memory_mapped_blocks[1:]
 
 
+def assert_deepcopy_without_bringing_data_in_memory(table: MemoryMappedTable):
+    with assert_arrow_memory_doesnt_increase():
+        copied_table = copy.deepcopy(table)
+        assert isinstance(copied_table, MemoryMappedTable)
+        assert copied_table.table == table.table
+
+
+def assert_deepcopy_does_bring_data_in_memory(table: MemoryMappedTable):
+    with assert_arrow_memory_increases():
+        copied_table = copy.deepcopy(table)
+        assert isinstance(copied_table, MemoryMappedTable)
+        assert copied_table.table == table.table
+
+
 def assert_pickle_without_bringing_data_in_memory(table: MemoryMappedTable):
     with assert_arrow_memory_doesnt_increase():
         pickled_table = pickle.dumps(table)
@@ -69,6 +84,12 @@ def assert_pickle_does_bring_data_in_memory(table: MemoryMappedTable):
     assert unpickled_table.table == table.table
 
 
+def assert_index_attributes_equal(table: Table, other: Table):
+    assert table._batches == other._batches
+    np.testing.assert_array_equal(table._offsets, other._offsets)
+    assert table._schema == other._schema
+
+
 def test_inject_arrow_table_documentation(in_memory_pa_table):
     method = pa.Table.slice
 
@@ -227,6 +248,24 @@ def test_in_memory_table_from_batches(in_memory_pa_table):
     assert isinstance(table, InMemoryTable)
 
 
+def test_in_memory_table_deepcopy(in_memory_pa_table):
+    table = InMemoryTable(in_memory_pa_table)
+    copied_table = copy.deepcopy(table)
+    assert table.table == copied_table.table
+    assert_index_attributes_equal(table, copied_table)
+    # deepcopy must return the exact same arrow objects since they are immutable
+    assert table.table is copied_table.table
+    assert all(batch1 is batch2 for batch1, batch2 in zip(table._batches, copied_table._batches))
+
+
+def test_in_memory_table_pickle(in_memory_pa_table):
+    table = InMemoryTable(in_memory_pa_table)
+    pickled_table = pickle.dumps(table)
+    unpickled_table = pickle.loads(pickled_table)
+    assert unpickled_table.table == table.table
+    assert_index_attributes_equal(table, unpickled_table)
+
+
 @slow
 def test_in_memory_table_pickle_big_table():
     big_table_4GB = InMemoryTable.from_pydict({"col": [0] * ((4 * 8 << 30) // 64)})
@@ -325,6 +364,7 @@ def test_memory_mapped_table_init(arrow_file, in_memory_pa_table):
     table = MemoryMappedTable(_memory_mapped_arrow_table_from_file(arrow_file), arrow_file)
     assert table.table == in_memory_pa_table
     assert isinstance(table, MemoryMappedTable)
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -333,6 +373,7 @@ def test_memory_mapped_table_from_file(arrow_file, in_memory_pa_table):
     table = MemoryMappedTable.from_file(arrow_file)
     assert table.table == in_memory_pa_table
     assert isinstance(table, MemoryMappedTable)
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -344,12 +385,34 @@ def test_memory_mapped_table_from_file_with_replay(arrow_file, in_memory_pa_tabl
     for method, args, kwargs in replays:
         in_memory_pa_table = getattr(in_memory_pa_table, method)(*args, **kwargs)
     assert table.table == in_memory_pa_table
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
+def test_memory_mapped_table_deepcopy(arrow_file):
+    table = MemoryMappedTable.from_file(arrow_file)
+    copied_table = copy.deepcopy(table)
+    assert table.table == copied_table.table
+    assert table.path == copied_table.path
+    assert_index_attributes_equal(table, copied_table)
+    # deepcopy must return the exact same arrow objects since they are immutable
+    assert table.table is copied_table.table
+    assert all(batch1 is batch2 for batch1, batch2 in zip(table._batches, copied_table._batches))
+
+
+def test_memory_mapped_table_pickle(arrow_file):
+    table = MemoryMappedTable.from_file(arrow_file)
+    pickled_table = pickle.dumps(table)
+    unpickled_table = pickle.loads(pickled_table)
+    assert unpickled_table.table == table.table
+    assert unpickled_table.path == table.path
+    assert_index_attributes_equal(table, unpickled_table)
+
+
 def test_memory_mapped_table_pickle_doesnt_fill_memory(arrow_file):
     with assert_arrow_memory_doesnt_increase():
         table = MemoryMappedTable.from_file(arrow_file)
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -359,6 +422,7 @@ def test_memory_mapped_table_pickle_applies_replay(arrow_file):
     table = MemoryMappedTable.from_file(arrow_file, replays=replays)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == replays
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -367,6 +431,7 @@ def test_memory_mapped_table_slice(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.slice(1, 2)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("slice", (1, 2), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -376,6 +441,7 @@ def test_memory_mapped_table_filter(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.filter(mask)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("filter", (mask,), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     # filter DOES increase memory
     # assert_pickle_without_bringing_data_in_memory(table)
     assert_pickle_does_bring_data_in_memory(table)
@@ -386,6 +452,7 @@ def test_memory_mapped_table_flatten(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.flatten()
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("flatten", tuple(), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -394,6 +461,7 @@ def test_memory_mapped_table_combine_chunks(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.combine_chunks()
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("combine_chunks", tuple(), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -409,6 +477,7 @@ def test_memory_mapped_table_cast(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.cast(schema)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("cast", (schema,), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     # cast DOES increase memory when converting integers precision for example
     # assert_pickle_without_bringing_data_in_memory(table)
     assert_pickle_does_bring_data_in_memory(table)
@@ -422,6 +491,7 @@ def test_memory_mapped_table_add_column(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.add_column(i, field_, column)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("add_column", (i, field_, column), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -432,6 +502,7 @@ def test_memory_mapped_table_append_column(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.append_column(field_, column)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("append_column", (field_, column), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -440,6 +511,7 @@ def test_memory_mapped_table_remove_column(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.remove_column(0)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("remove_column", (0,), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -451,6 +523,7 @@ def test_memory_mapped_table_set_column(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.set_column(i, field_, column)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("set_column", (i, field_, column), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -461,6 +534,7 @@ def test_memory_mapped_table_rename_columns(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.rename_columns(names)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("rename_columns", (names,), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -470,6 +544,7 @@ def test_memory_mapped_table_drop(arrow_file, in_memory_pa_table):
     assert table.table == in_memory_pa_table.drop(names)
     assert isinstance(table, MemoryMappedTable)
     assert table.replays == [("drop", (names,), {})]
+    assert_deepcopy_without_bringing_data_in_memory(table)
     assert_pickle_without_bringing_data_in_memory(table)
 
 
@@ -556,6 +631,42 @@ def test_concatenation_table_from_tables(axis, in_memory_pa_table, arrow_file):
     assert isinstance(table.blocks[1][0] if axis == 0 else table.blocks[0][1], MemoryMappedTable)
 
 
+@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
+def test_concatenation_table_deepcopy(
+    blocks_type, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
+):
+    blocks = {
+        "in_memory": in_memory_blocks,
+        "memory_mapped": memory_mapped_blocks,
+        "mixed": mixed_in_memory_and_memory_mapped_blocks,
+    }[blocks_type]
+    table = ConcatenationTable.from_blocks(blocks)
+    copied_table = copy.deepcopy(table)
+    assert table.table == copied_table.table
+    assert table.blocks == copied_table.blocks
+    assert_index_attributes_equal(table, copied_table)
+    # deepcopy must return the exact same arrow objects since they are immutable
+    assert table.table is copied_table.table
+    assert all(batch1 is batch2 for batch1, batch2 in zip(table._batches, copied_table._batches))
+
+
+@pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
+def test_concatenation_table_pickle(
+    blocks_type, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
+):
+    blocks = {
+        "in_memory": in_memory_blocks,
+        "memory_mapped": memory_mapped_blocks,
+        "mixed": mixed_in_memory_and_memory_mapped_blocks,
+    }[blocks_type]
+    table = ConcatenationTable.from_blocks(blocks)
+    pickled_table = pickle.dumps(table)
+    unpickled_table = pickle.loads(pickled_table)
+    assert unpickled_table.table == table.table
+    assert unpickled_table.blocks == table.blocks
+    assert_index_attributes_equal(table, unpickled_table)
+
+
 @pytest.mark.parametrize("blocks_type", ["in_memory", "memory_mapped", "mixed"])
 def test_concatenation_table_slice(
     blocks_type, in_memory_pa_table, in_memory_blocks, memory_mapped_blocks, mixed_in_memory_and_memory_mapped_blocks
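
Background on the technique: copy.deepcopy consults its memo dict (keyed by
id(obj)) before copying anything, so pre-seeding the memo with an object makes
deepcopy return that object as-is instead of duplicating it. pyarrow Tables and
RecordBatches are immutable, so sharing them between the original and the copy
is safe and avoids materializing the data in memory. The sketch below is a
self-contained approximation of the pattern, not the datasets implementation:
TableWrapper is a made-up class, and _deepcopy_instance only approximates the
module-level _deepcopy helper that the diff calls but does not show.

import copy

import pyarrow as pa


def _deepcopy_instance(x, memo: dict):
    # Attribute-by-attribute deepcopy of a regular class instance; this is
    # what lets __deepcopy__ below delegate without infinite recursion.
    cls = x.__class__
    result = cls.__new__(cls)
    memo[id(x)] = result  # guard against reference cycles back to x
    for k, v in x.__dict__.items():
        setattr(result, k, copy.deepcopy(v, memo))
    return result


class TableWrapper:
    """Toy stand-in for a datasets Table: wraps an immutable pa.Table."""

    def __init__(self, table: pa.Table):
        self.table = table
        self._batches = table.to_batches()

    def __deepcopy__(self, memo: dict):
        # Seed the memo so deepcopy reuses the immutable arrow objects
        # instead of duplicating their buffers.
        memo[id(self.table)] = self.table
        # A fresh list object, but the same shared recordbatches inside.
        memo[id(self._batches)] = list(self._batches)
        return _deepcopy_instance(self, memo)


if __name__ == "__main__":
    original = TableWrapper(pa.table({"col": list(range(1_000))}))
    allocated_before = pa.total_allocated_bytes()
    copied = copy.deepcopy(original)
    # The arrow objects are shared, so no new arrow memory is allocated.
    assert copied.table is original.table
    assert all(b1 is b2 for b1, b2 in zip(original._batches, copied._batches))
    assert pa.total_allocated_bytes() == allocated_before

Note that the memo entry for _batches maps to a new list rather than the
original list object: the copy gets its own mutable list while still sharing
the immutable recordbatches inside it, mirroring the
memo[id(self._batches)] = list(self._batches) line in the patch.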