Skip to content

Commit

Permalink
Improve error message in st.experimental_memo when it returns an unevaluated dataframe (streamlit#5515)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-tszerszen authored Oct 14, 2022
1 parent 5f23c89 commit 3f78590
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
8 changes: 4 additions & 4 deletions lib/streamlit/runtime/caching/cache_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)


def _get_cached_func_name_md(func) -> str:
def get_cached_func_name_md(func) -> str:
"""Get markdown representation of the function name."""
if hasattr(func, "__name__"):
return "`%s()`" % func.__name__
Expand All @@ -36,7 +36,7 @@ def _get_cached_func_name_md(func) -> str:
def get_return_value_type(return_value) -> str:
if hasattr(return_value, "__module__") and hasattr(type(return_value), "__name__"):
return f"`{return_value.__module__}.{type(return_value).__name__}`"
return _get_cached_func_name_md(return_value)
return get_cached_func_name_md(return_value)


class CacheType(enum.Enum):
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(
cache_type: CacheType,
cached_func: types.FunctionType,
):
func_name = _get_cached_func_name_md(cached_func)
func_name = get_cached_func_name_md(cached_func)
decorator_name = (cache_type.value,)

msg = (
Expand All @@ -165,7 +165,7 @@ def __init__(self, func: types.FunctionType, return_value: types.FunctionType):
MarkdownFormattedException.__init__(
self,
f"""
Cannot serialize the return value (of type {get_return_value_type(return_value)}) in {_get_cached_func_name_md(func)}.
Cannot serialize the return value (of type {get_return_value_type(return_value)}) in {get_cached_func_name_md(func)}.
`st.experimental_memo` uses [pickle](https://docs.python.org/3/library/pickle.html) to
serialize the function’s return value and safely store it in the cache without mutating the original object. Please convert the return value to a pickle-serializable type.
If you want to cache unserializable objects such as database connections or Tensorflow
Expand Down
19 changes: 18 additions & 1 deletion lib/streamlit/runtime/caching/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
from google.protobuf.message import Message

import streamlit as st
from streamlit import util
from streamlit import type_util, util
from streamlit.elements import NONWIDGET_ELEMENTS
from streamlit.elements.spinner import spinner
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_errors import (
Expand All @@ -51,6 +52,7 @@
UnhashableParamError,
UnhashableTypeError,
UnserializableReturnValueError,
get_cached_func_name_md,
)
from streamlit.runtime.caching.hashing import update_hash

Expand Down Expand Up @@ -253,6 +255,21 @@ def get_or_create_cached_value():
try:
cache.write_result(value_key, return_value, messages)
except TypeError:
if type_util.is_type(
return_value, "snowflake.snowpark.dataframe.DataFrame"
):

class UnevaluatedDataFrameError(StreamlitAPIException):
    """Error for a memoized function that returns a Snowpark DataFrame.

    Raised from the ``except TypeError`` path of the cache write when the
    return value is detected as ``snowflake.snowpark.dataframe.DataFrame``.
    The message tells the user to materialize the dataframe (``collect()``
    or ``to_pandas()``) before returning it.

    NOTE(review): the message interpolates ``func`` from the enclosing
    scope; confirm that name refers to the user's decorated function.
    """

    def __init__(self):
        # `super().__init__` is already bound to this instance. Passing
        # `self` again would make the instance the first exception argument
        # and push the message to the second, so the message must be the
        # only argument here.
        super().__init__(
            f"""
The function {get_cached_func_name_md(func)} is decorated with `st.experimental_memo` but it returns an unevaluated
dataframe of type `snowflake.snowpark.DataFrame`. Please call `collect()` or `to_pandas()` on the
dataframe before returning it, so `st.experimental_memo` can serialize and cache it."""
        )

raise UnevaluatedDataFrameError
raise UnserializableReturnValueError(
return_value=return_value, func=cached_func.func
)
Expand Down

0 comments on commit 3f78590

Please sign in to comment.