Skip to content

Commit

Permalink
Add option to apply automatic dataframe truncation (streamlit#8002)
Browse files Browse the repository at this point in the history
* Add prototype for auto dataframe truncation

* Update solution

* Fix nbytes

* Clean up dimensions

* Refactor data dimensions

* Update warning message

* Removed unused import

* Add comment

* Fix test

* Add unit test

* Update tests

* Finalize e2e tests

* Rename to arrow truncation

* Add license header

* Rename

* Update decorator usage

* Double the rows

* Temp: try with dataframe only

* Change count

* Fix test

* Finalize e2e test

* Fix tests

* Add some prints

* Add more prints

* Fix copy paste error

* Some refactoring

* Fix type util test

* Use higher pyarrow version

* Fix type util

* Update pyarrow min constraint

* Fix warning

* Apply more updates

* Remove comment

* Fix issue

* Remove print

* Apply comment
  • Loading branch information
lukasmasuch authored Jan 26, 2024
1 parent f4684e2 commit 2a29dab
Show file tree
Hide file tree
Showing 10 changed files with 285 additions and 33 deletions.
24 changes: 24 additions & 0 deletions e2e_playwright/config_arrow_truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd

import streamlit as st

# Fixed seed so the generated frame is identical on every test run.
np.random.seed(0)

# A 50k x 20 frame of standard-normal floats — big enough to exceed a small
# `server.maxMessageSize` and exercise the arrow-truncation code path.
column_labels = ["col %d" % i for i in range(20)]
df = pd.DataFrame(np.random.randn(50000, 20), columns=column_labels)

st.dataframe(df)
37 changes: 37 additions & 0 deletions e2e_playwright/config_arrow_truncation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
from playwright.sync_api import Page, expect


@pytest.fixture(scope="module")
@pytest.mark.early
def configure_arrow_truncation():
    """Enable arrow truncation and shrink the max message size via env vars.

    The variables are picked up by Streamlit's config system; they are removed
    again after every test in this module has run.
    """
    env_overrides = {
        "STREAMLIT_SERVER_ENABLE_ARROW_TRUNCATION": "True",
        "STREAMLIT_SERVER_MAX_MESSAGE_SIZE": "3",
    }
    os.environ.update(env_overrides)
    yield
    for key in env_overrides:
        del os.environ[key]


def test_shows_limitation_message(app: Page, configure_arrow_truncation):
    """The app should render exactly one caption warning about truncated rows."""
    captions = app.get_by_test_id("stCaptionContainer")
    expect(captions).to_have_count(1)
    # `.first` is equivalent to `.nth(0)` for a Playwright locator.
    expect(captions.first).to_have_text(
        "⚠️ Showing 12k out of 50k rows due to data size limitations. "
    )
6 changes: 3 additions & 3 deletions frontend/lib/src/components/widgets/DataFrame/arrowUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ export function getEmptyIndexColumn(): BaseColumnProps {
export function getAllColumnsFromArrow(data: Quiver): BaseColumnProps[] {
const columns: BaseColumnProps[] = []

// TODO(lukasmasuch): use data.dimensions instead here?
const numIndices = data.types?.index?.length ?? 0
const numColumns = data.columns?.[0]?.length ?? 0
const { dimensions } = data
const numIndices = dimensions.headerColumns
const numColumns = dimensions.dataColumns

if (numIndices === 0 && numColumns === 0) {
// Tables that don't have any columns cause an exception in glide-data-grid.
Expand Down
32 changes: 4 additions & 28 deletions frontend/lib/src/dataframes/Quiver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1203,34 +1203,10 @@ but was expecting \`${JSON.stringify(expectedIndexTypes)}\`.

/** The DataFrame's dimensions. */
public get dimensions(): DataFrameDimensions {
// TODO(lukasmasuch): this._index[0].length can be 0 if there are rows
// but only an empty index. Probably not the best way to cross check the number
// of rows.
const [headerColumns, dataRowsCheck] = this._index.length
? [this._index.length, this._index[0].length]
: [1, 0]

const [headerRows, dataColumnsCheck] = this._columns.length
? [this._columns.length, this._columns[0].length]
: [1, 0]

const [dataRows, dataColumns] = this._data.numRows
? [this._data.numRows, this._data.numCols]
: // If there is no data, default to the number of header columns.
[0, dataColumnsCheck]

// Sanity check: ensure the schema is not messed up. If this happens,
// something screwy probably happened in addRows.
if (
(dataRows !== 0 && dataRows !== dataRowsCheck) ||
(dataColumns !== 0 && dataColumns !== dataColumnsCheck)
) {
throw new Error(
"Dataframe dimensions don't align: " +
`rows(${dataRows} != ${dataRowsCheck}) OR ` +
`cols(${dataColumns} != ${dataColumnsCheck})`
)
}
const headerColumns = this._index.length || this.types.index.length || 1
const headerRows = this._columns.length || 1
const dataRows = this._data.numRows || 0
const dataColumns = this._data.numCols || this._columns?.[0]?.length || 0

const rows = headerRows + dataRows
const columns = headerColumns + dataColumns
Expand Down
2 changes: 1 addition & 1 deletion lib/min-constraints-gen.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ packaging==16.8
pandas==1.3.0
pillow==7.1.0
protobuf==3.20
pyarrow==6.0
pyarrow==7.0
pydeck==0.8.0b4
python-dateutil==2.7.3
requests==2.27
Expand Down
2 changes: 1 addition & 1 deletion lib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# pyarrow is not semantically versioned, gets new major versions frequently, and
# doesn't tend to break the API on major version upgrades, so we don't put an
# upper bound on it.
"pyarrow>=6.0",
"pyarrow>=7.0",
"python-dateutil>=2.7.3, <3",
"requests>=2.27, <3",
"rich>=10.14.0, <14",
Expand Down
12 changes: 12 additions & 0 deletions lib/streamlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,18 @@ def _server_address() -> Optional[str]:
type_=int,
)

_create_option(
"server.enableArrowTruncation",
description="""
Enable automatically truncating all data structures that get serialized into Arrow (e.g. DataFrames)
to ensure that the size is under `server.maxMessageSize`.
""",
visibility="hidden",
default_val=False,
scriptable=True,
type_=bool,
)

_create_option(
"server.enableWebsocketCompression",
description="""
Expand Down
91 changes: 91 additions & 0 deletions lib/streamlit/type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import contextlib
import copy
import math
import re
import types
from enum import Enum, EnumMeta, auto
Expand Down Expand Up @@ -712,6 +713,82 @@ def is_pyarrow_version_less_than(v: str) -> bool:
return version.parse(pa.__version__) < version.parse(v)


def _maybe_truncate_table(
    table: pa.Table, truncated_rows: int | None = None
) -> pa.Table:
    """Experimental feature to automatically truncate tables that
    are larger than the maximum allowed message size. It needs to be enabled
    via the server.enableArrowTruncation config option.

    Parameters
    ----------
    table : pyarrow.Table
        A table to truncate.
    truncated_rows : int or None
        The number of rows that have already been truncated from the table
        before this call. Included in the user-facing caption.

    Returns
    -------
    pyarrow.Table
        The original table if truncation is disabled or unnecessary,
        otherwise a zero-copy slice of its first rows that is estimated to
        fit within ``server.maxMessageSize``.
    """
    if config.get_option("server.enableArrowTruncation"):
        # This is an optimization problem: we don't know at what row
        # the perfect cut-off is to comply with the max size. But we want to
        # figure it out in as few iterations as possible. We almost always
        # will cut out more than required to keep the iterations low.
        # Note: this is implemented as a loop (not recursion) on purpose —
        # each step is guaranteed to remove at least one row, but for huge
        # tables the number of steps can be large enough that a recursive
        # version risks a RecursionError.

        # The maximum size allowed for protobuf messages in bytes:
        max_message_size = int(config.get_option("server.maxMessageSize") * 1e6)
        truncated_rows = truncated_rows or 0

        while True:
            # We add 1 MB for other overhead related to the protobuf message.
            # This is a very conservative estimate, but it should be good enough.
            table_size = int(table.nbytes + 1 * 1e6)
            table_rows = table.num_rows

            if table_rows <= 1 or table_size <= max_message_size:
                break

            # targeted_rows == the number of rows the table should be truncated to.
            # Calculate an approximation of how many rows we need to truncate to.
            targeted_rows = math.ceil(table_rows * (max_message_size / table_size))
            # Make sure to cut out at least a couple of rows to avoid running
            # this logic too often, since it is quite inefficient.
            targeted_rows = math.floor(
                max(
                    min(
                        # Cut out:
                        # an additional 5% of the estimated num rows to cut out:
                        targeted_rows - math.floor((table_rows - targeted_rows) * 0.05),
                        # at least 1% of table size:
                        table_rows - (table_rows * 0.01),
                        # at least 5 rows:
                        table_rows - 5,
                    ),
                    1,  # but it should always have at least 1 row
                )
            )
            # pyarrow slices are zero-copy, so this is cheap.
            table = table.slice(0, targeted_rows)
            truncated_rows += table_rows - targeted_rows

        if truncated_rows:
            displayed_rows = string_util.simplify_number(table.num_rows)
            total_rows = string_util.simplify_number(table.num_rows + truncated_rows)

            if displayed_rows == total_rows:
                # If the simplified numbers are the same,
                # we just display the exact numbers.
                displayed_rows = str(table.num_rows)
                total_rows = str(table.num_rows + truncated_rows)

            st.caption(
                f"⚠️ Showing {displayed_rows} out of {total_rows} "
                "rows due to data size limitations."
            )

    return table


def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
"""Serialize pyarrow.Table to bytes using Apache Arrow.
Expand All @@ -721,6 +798,20 @@ def pyarrow_table_to_bytes(table: pa.Table) -> bytes:
A table to convert.
"""
try:
table = _maybe_truncate_table(table)
except RecursionError as err:
# This is a very unlikely edge case, but we want to make sure that
# it doesn't lead to unexpected behavior.
# If there is a recursion error, we just return the table as-is
# which will lead to the normal message limit exceed error.
_LOGGER.warning(
"Recursion error while truncating Arrow table. This is not "
"supposed to happen.",
exc_info=err,
)

# Convert table to bytes
sink = pa.BufferOutputStream()
writer = pa.RecordBatchStreamWriter(sink, table.schema)
writer.write_table(table)
Expand Down
1 change: 1 addition & 0 deletions lib/tests/streamlit/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def test_config_option_keys(self):
"server.maxUploadSize",
"server.maxMessageSize",
"server.enableStaticServing",
"server.enableArrowTruncation",
"server.sslCertFile",
"server.sslKeyFile",
"ui.hideTopBar",
Expand Down
Loading

0 comments on commit 2a29dab

Please sign in to comment.