Skip to content

Commit

Permalink
add computed backend vars (reflex-dev#3573)
Browse files Browse the repository at this point in the history
* add computed backend vars

* finish computed backend vars, add tests

* fix token for AppHarness with redis state manager

* fix timing issues

* add unit tests for computed backend vars

* automagically mark cvs with _ prefix as backend var

* fully migrate backend computed vars

* rename is_backend_variable to is_backend_base_variable

* add integration test for implicit backend cv, adjust comments

* replace expensive backend var check at runtime

* keep stuff together

* simplify backend var check method, consistent naming, improve test typing

* fix: do not convert properties to cvs

* add test for property

* fix cached_properties with _ prefix in state cls
  • Loading branch information
benedikt-bartscher authored Jun 29, 2024
1 parent bcc7a61 commit b7651e2
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 49 deletions.
37 changes: 36 additions & 1 deletion integration/test_computed_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ class State(StateMixin, rx.State):
def count1(self) -> int:
return self.count

# cached backend var with dep on count
@rx.var(cache=True, interval=15, backend=True)
def count1_backend(self) -> int:
return self.count

# same as above but implicit backend with `_` prefix
@rx.var(cache=True, interval=15)
def _count1_backend(self) -> int:
return self.count

# explicit disabled auto_deps
@rx.var(interval=15, cache=True, auto_deps=False)
def count3(self) -> int:
Expand Down Expand Up @@ -70,6 +80,10 @@ def index() -> rx.Component:
rx.text(State.count, id="count"),
rx.text("count1:"),
rx.text(State.count1, id="count1"),
rx.text("count1_backend:"),
rx.text(State.count1_backend, id="count1_backend"),
rx.text("_count1_backend:"),
rx.text(State._count1_backend, id="_count1_backend"),
rx.text("count3:"),
rx.text(State.count3, id="count3"),
rx.text("depends_on_count:"),
Expand Down Expand Up @@ -154,7 +168,8 @@ def token(computed_vars: AppHarness, driver: WebDriver) -> str:
return token


def test_computed_vars(
@pytest.mark.asyncio
async def test_computed_vars(
computed_vars: AppHarness,
driver: WebDriver,
token: str,
Expand All @@ -168,6 +183,20 @@ def test_computed_vars(
"""
assert computed_vars.app_instance is not None

token = f"{token}_state.state"
state = (await computed_vars.get_state(token)).substates["state"]
assert state is not None
assert state.count1_backend == 0
assert state._count1_backend == 0

# test that backend var is not rendered
count1_backend = driver.find_element(By.ID, "count1_backend")
assert count1_backend
assert count1_backend.text == ""
_count1_backend = driver.find_element(By.ID, "_count1_backend")
assert _count1_backend
assert _count1_backend.text == ""

count = driver.find_element(By.ID, "count")
assert count
assert count.text == "0"
Expand Down Expand Up @@ -207,6 +236,12 @@ def test_computed_vars(
computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0")
== "1"
)
state = (await computed_vars.get_state(token)).substates["state"]
assert state is not None
assert state.count1_backend == 1
assert count1_backend.text == ""
assert state._count1_backend == 1
assert _count1_backend.text == ""

mark_dirty.click()
with pytest.raises(TimeoutError):
Expand Down
42 changes: 22 additions & 20 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Vars inherited by the parent state.
inherited_vars: ClassVar[Dict[str, Var]] = {}

# Backend vars that are never sent to the client.
# Backend base vars that are never sent to the client.
backend_vars: ClassVar[Dict[str, Any]] = {}

# Backend vars inherited
# Backend base vars inherited
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}

# The event handlers.
Expand Down Expand Up @@ -344,7 +344,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The routing path that triggered the state
router_data: Dict[str, Any] = {}

# Per-instance copy of backend variable values
# Per-instance copy of backend base variable values
_backend_vars: Dict[str, Any] = {}

# The router data for the current page
Expand Down Expand Up @@ -492,21 +492,12 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
new_backend_vars = {
name: value
for name, value in cls.__dict__.items()
if types.is_backend_variable(name, cls)
}

# Get backend computed vars
backend_computed_vars = {
v._var_name: v._var_set_state(cls)
for v in computed_vars
if types.is_backend_variable(v._var_name, cls)
and v._var_name not in cls.inherited_backend_vars
if types.is_backend_base_variable(name, cls)
}

cls.backend_vars = {
**cls.inherited_backend_vars,
**new_backend_vars,
**backend_computed_vars,
}

# Set the base and computed vars.
Expand Down Expand Up @@ -548,7 +539,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs):
cls.computed_vars[newcv._var_name] = newcv
cls.vars[newcv._var_name] = newcv
continue
if types.is_backend_variable(name, mixin):
if types.is_backend_base_variable(name, mixin):
cls.backend_vars[name] = copy.deepcopy(value)
continue
if events.get(name) is not None:
Expand Down Expand Up @@ -1087,7 +1078,7 @@ def __setattr__(self, name: str, value: Any):
setattr(self.parent_state, name, value)
return

if types.is_backend_variable(name, type(self)):
if name in self.backend_vars:
self._backend_vars.__setitem__(name, value)
self.dirty_vars.add(name)
self._mark_dirty()
Expand Down Expand Up @@ -1538,11 +1529,14 @@ def _expired_computed_vars(self) -> set[str]:
if self.computed_vars[cvar].needs_update(instance=self)
)

def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
def _dirty_computed_vars(
self, from_vars: set[str] | None = None, include_backend: bool = True
) -> set[str]:
"""Determine ComputedVars that need to be recalculated based on the given vars.
Args:
from_vars: find ComputedVar that depend on this set of vars. If unspecified, will use the dirty_vars.
include_backend: whether to include backend vars in the calculation.
Returns:
Set of computed vars to include in the delta.
Expand All @@ -1551,6 +1545,7 @@ def _dirty_computed_vars(self, from_vars: set[str] | None = None) -> set[str]:
cvar
for dirty_var in from_vars or self.dirty_vars
for cvar in self._computed_var_dependencies[dirty_var]
if include_backend or not self.computed_vars[cvar]._backend
)

@classmethod
Expand Down Expand Up @@ -1586,19 +1581,23 @@ def get_delta(self) -> Delta:
self.dirty_vars.update(self._always_dirty_computed_vars)
self._mark_dirty()

frontend_computed_vars: set[str] = {
name for name, cv in self.computed_vars.items() if not cv._backend
}

# Return the dirty vars for this instance, any cached/dependent computed vars,
# and always dirty computed vars (cache=False)
delta_vars = (
self.dirty_vars.intersection(self.base_vars)
.union(self.dirty_vars.intersection(self.computed_vars))
.union(self._dirty_computed_vars())
.union(self.dirty_vars.intersection(frontend_computed_vars))
.union(self._dirty_computed_vars(include_backend=False))
.union(self._always_dirty_computed_vars)
)

subdelta = {
prop: getattr(self, prop)
for prop in delta_vars
if not types.is_backend_variable(prop, type(self))
if not types.is_backend_base_variable(prop, type(self))
}
if len(subdelta) > 0:
delta[self.get_full_name()] = subdelta
Expand Down Expand Up @@ -1727,12 +1726,14 @@ def dict(
else self.get_value(getattr(self, prop_name))
)
for prop_name, cv in self.computed_vars.items()
if not cv._backend
}
elif include_computed:
computed_vars = {
# Include the computed vars.
prop_name: self.get_value(getattr(self, prop_name))
for prop_name in self.computed_vars
for prop_name, cv in self.computed_vars.items()
if not cv._backend
}
else:
computed_vars = {}
Expand All @@ -1745,6 +1746,7 @@ def dict(
for v in self.substates.values()
]:
d.update(substate_d)

return d

async def __aenter__(self) -> BaseState:
Expand Down
45 changes: 22 additions & 23 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import inspect
import sys
import types
from functools import wraps
from functools import cached_property, wraps
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -410,7 +410,7 @@ def is_valid_var_type(type_: Type) -> bool:
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)


def is_backend_variable(name: str, cls: Type | None = None) -> bool:
def is_backend_base_variable(name: str, cls: Type) -> bool:
"""Check if this variable name correspond to a backend variable.
Args:
Expand All @@ -429,31 +429,30 @@ def is_backend_variable(name: str, cls: Type | None = None) -> bool:
if name.startswith("__"):
return False

if cls is not None:
if name.startswith(f"_{cls.__name__}__"):
if name.startswith(f"_{cls.__name__}__"):
return False

hints = get_type_hints(cls)
if name in hints:
hint = get_origin(hints[name])
if hint == ClassVar:
return False
hints = get_type_hints(cls)
if name in hints:
hint = get_origin(hints[name])
if hint == ClassVar:
return False

if name in cls.inherited_backend_vars:
if name in cls.inherited_backend_vars:
return False

if name in cls.__dict__:
value = cls.__dict__[name]
if type(value) == classmethod:
return False
if callable(value):
return False
from reflex.vars import ComputedVar

if name in cls.__dict__:
value = cls.__dict__[name]
if type(value) == classmethod:
return False
if callable(value):
return False
if isinstance(value, types.FunctionType):
return False
# enable after #3573 is merged
# from reflex.vars import ComputedVar
#
# if isinstance(value, ComputedVar):
# return False
if isinstance(
value, (types.FunctionType, property, cached_property, ComputedVar)
):
return False

return True

Expand Down
13 changes: 13 additions & 0 deletions reflex/vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -1944,6 +1944,9 @@ class ComputedVar(Var, property):
# Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False)

# Whether the computed var is a backend var
_backend: bool = dataclasses.field(default=False)

# The initial value of the computed var
_initial_value: Any | types.Unset = dataclasses.field(default=types.Unset())

Expand All @@ -1964,6 +1967,7 @@ def __init__(
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None,
backend: bool | None = None,
**kwargs,
):
"""Initialize a ComputedVar.
Expand All @@ -1975,11 +1979,16 @@ def __init__(
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
**kwargs: additional attributes to set on the instance
Raises:
TypeError: If the computed var dependencies are not Var instances or var names.
"""
if backend is None:
backend = fget.__name__.startswith("_")
self._backend = backend

self._initial_value = initial_value
self._cache = cache
if isinstance(interval, int):
Expand Down Expand Up @@ -2023,6 +2032,7 @@ def _replace(self, merge_var_data=None, **kwargs: Any) -> ComputedVar:
deps=kwargs.get("deps", self._static_deps),
auto_deps=kwargs.get("auto_deps", self._auto_deps),
interval=kwargs.get("interval", self._update_interval),
backend=kwargs.get("backend", self._backend),
_var_name=kwargs.get("_var_name", self._var_name),
_var_type=kwargs.get("_var_type", self._var_type),
_var_is_local=kwargs.get("_var_is_local", self._var_is_local),
Expand Down Expand Up @@ -2233,6 +2243,7 @@ def computed_var(
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None,
_deprecated_cached_var: bool = False,
**kwargs,
) -> ComputedVar | Callable[[Callable[[BaseState], Any]], ComputedVar]:
Expand All @@ -2245,6 +2256,7 @@ def computed_var(
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
_deprecated_cached_var: Indicate usage of deprecated cached_var partial function.
**kwargs: additional attributes to set on the instance
Expand Down Expand Up @@ -2280,6 +2292,7 @@ def wrapper(fget: Callable[[BaseState], Any]) -> ComputedVar:
deps=deps,
auto_deps=auto_deps,
interval=interval,
backend=backend,
**kwargs,
)

Expand Down
Loading

0 comments on commit b7651e2

Please sign in to comment.