Skip to content

Commit

Permalink
Remove assert uses from playwright tests to fix flakiness (streamli…
Browse files Browse the repository at this point in the history
…t#8062)

* Remove uses of `assert` in playwright tests to fix flakiness

* Remove more asserts

* Mini fix

* Fix test

* Run wait check in one function
  • Loading branch information
lukasmasuch authored Feb 1, 2024
1 parent c1535f0 commit 516e67c
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 79 deletions.
53 changes: 53 additions & 0 deletions e2e_playwright/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,56 @@ def compare(

if test_failure_messages:
pytest.fail("Missing snapshots: \n" + "\n".join(test_failure_messages))


def wait_until(page: Page, fn: callable, timeout: int = 5000, interval: int = 100):
"""Run a test function in a loop until it evaluates to True
or times out.
For example:
>>> wait_until(lambda: x.values() == ['x'], page)
Parameters
----------
page : playwright.sync_api.Page
Playwright page
fn : callable
Callback
timeout : int, optional
Total timeout in milliseconds, by default 5000
interval : int, optional
Waiting interval, by default 100
Adapted from panel.
"""
# Hide this function traceback from the pytest output if the test fails
__tracebackhide__ = True

start = time.time()

def timed_out():
elapsed = time.time() - start
elapsed_ms = elapsed * 1000
return elapsed_ms > timeout

timeout_msg = f"wait_until timed out in {timeout} milliseconds"

while True:
try:
result = fn()
except AssertionError as e:
if timed_out():
raise TimeoutError(timeout_msg) from e
else:
if result not in (None, True, False):
raise ValueError(
"`wait_until` callback must return None, True or "
f"False, returned {result!r}"
)
# Stop is result is True or None
# None is returned when the function has an assert
if result is None or result:
return
if timed_out():
raise TimeoutError(timeout_msg)
page.wait_for_timeout(interval)
16 changes: 9 additions & 7 deletions e2e_playwright/fast_rerun_safety_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@

from playwright.sync_api import Page, expect

from e2e_playwright.conftest import wait_for_app_run


def test_no_concurrent_changes(app: Page):
counters = app.locator(".stMarkdown")
counters = app.get_by_test_id("stMarkdown")
expect(counters.first).to_have_text("0", use_inner_text=True)

button = app.locator(".stButton")
button = app.get_by_test_id("stButton")
button.first.click()
app.wait_for_timeout(300)
wait_for_app_run(app)

counters = app.locator(".stMarkdown")
c1 = counters.nth(0).inner_text()
c2 = counters.nth(1).inner_text()
assert c1 == c2
counters = app.get_by_test_id("stMarkdown")
expect(counters.nth(0)).to_have_text(
counters.nth(1).inner_text(), use_inner_text=True
)
36 changes: 23 additions & 13 deletions e2e_playwright/label_markdown_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def test_radio_labels_handle_markdown(app: Page, assert_snapshot: ImageCompareFu
["valid", "link"],
]

radioes = app.get_by_test_id("stRadio")
expect(radioes).to_have_count(4)
radios = app.get_by_test_id("stRadio")
expect(radios).to_have_count(4)
for index, case in enumerate(cases):
assert_snapshot(
radioes.nth(index).get_by_test_id("stWidgetLabel"),
radios.nth(index).get_by_test_id("stWidgetLabel"),
name=f"st_radio-{case[0]}_{case[1]}",
)

Expand Down Expand Up @@ -158,7 +158,8 @@ def test_text_input_labels_handle_markdown(
]

text_inputs = app.get_by_test_id("stTextInput")
assert text_inputs.count() == 4
expect(text_inputs).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
text_inputs.nth(index).get_by_test_id("stWidgetLabel"),
Expand All @@ -177,7 +178,8 @@ def test_number_input_labels_handle_markdown(
]

number_inputs = app.get_by_test_id("stNumberInput")
assert number_inputs.count() == 4
expect(number_inputs).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
number_inputs.nth(index).get_by_test_id("stWidgetLabel"),
Expand All @@ -196,7 +198,8 @@ def test_text_area_labels_handle_markdown(
]

text_areas = app.get_by_test_id("stTextArea")
assert text_areas.count() == 4
expect(text_areas).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
text_areas.nth(index).get_by_test_id("stWidgetLabel"),
Expand All @@ -215,7 +218,8 @@ def test_date_input_labels_handle_markdown(
]

date_inputs = app.get_by_test_id("stDateInput")
assert date_inputs.count() == 4
expect(date_inputs).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
date_inputs.nth(index).get_by_test_id("stWidgetLabel"),
Expand All @@ -234,7 +238,8 @@ def test_time_input_labels_handle_markdown(
]

time_inputs = app.get_by_test_id("stTimeInput")
assert time_inputs.count() == 4
expect(time_inputs).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
time_inputs.nth(index).get_by_test_id("stWidgetLabel"),
Expand All @@ -253,7 +258,8 @@ def test_file_uploader_labels_handle_markdown(
]

file_uploaders = app.get_by_test_id("stFileUploader")
assert file_uploaders.count() == 4
expect(file_uploaders).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
file_uploaders.nth(index).get_by_test_id("stWidgetLabel"),
Expand All @@ -272,7 +278,8 @@ def test_color_picker_labels_handle_markdown(
]

color_pickers = app.get_by_test_id("stColorPicker")
assert color_pickers.count() == 4
expect(color_pickers).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
color_pickers.nth(index).get_by_test_id("stWidgetLabel"),
Expand All @@ -291,7 +298,8 @@ def test_metric_labels_handle_markdown(
]

metrics = app.get_by_test_id("stMetric")
assert metrics.count() == 4
expect(metrics).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
metrics.nth(index).get_by_test_id("stMetricLabel"),
Expand All @@ -310,7 +318,8 @@ def test_expander_labels_handle_markdown(
]

expanders = app.get_by_test_id("stExpander")
assert expanders.count() == 4
expect(expanders).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
expanders.nth(index),
Expand All @@ -327,7 +336,8 @@ def test_tab_labels_handle_markdown(app: Page, assert_snapshot: ImageCompareFunc
]

tabs = app.get_by_test_id("stTab")
assert tabs.count() == 4
expect(tabs).to_have_count(4)

for index, case in enumerate(cases):
assert_snapshot(
tabs.nth(index),
Expand Down
11 changes: 4 additions & 7 deletions e2e_playwright/multipage_apps/mpa_basics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ def test_switch_page_removes_query_params(page: Page, app_port: int):
# Trigger st.switch_page
page.get_by_test_id("stButton").locator("button").first.click()
wait_for_app_loaded(page)

# Check that query params don't persist
assert page.url == f"http://localhost:{app_port}/page2"
expect(page).to_have_url(f"http://localhost:{app_port}/page2")


def test_removes_query_params_when_swapping_pages(page: Page, app_port: int):
Expand All @@ -174,8 +173,7 @@ def test_removes_query_params_when_swapping_pages(page: Page, app_port: int):

page.get_by_test_id("stSidebarNav").locator("a").nth(2).click()
wait_for_app_loaded(page)

assert page.url == f"http://localhost:{app_port}/page3"
expect(page).to_have_url(f"http://localhost:{app_port}/page3")


def test_removes_non_embed_query_params_when_swapping_pages(page: Page, app_port: int):
Expand All @@ -189,7 +187,6 @@ def test_removes_non_embed_query_params_when_swapping_pages(page: Page, app_port
page.get_by_test_id("stSidebarNav").locator("a").nth(2).click()
wait_for_app_loaded(page)

assert (
page.url
== f"http://localhost:{app_port}/page3?embed=true&embed_options=show_toolbar&embed_options=show_colored_line"
expect(page).to_have_url(
f"http://localhost:{app_port}/page3?embed=true&embed_options=show_toolbar&embed_options=show_colored_line"
)
4 changes: 2 additions & 2 deletions e2e_playwright/st_audio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from playwright.sync_api import Page, expect

Expand All @@ -19,5 +20,4 @@ def test_audio_has_correct_properties(app: Page):
audio = app.get_by_test_id("stAudio")
expect(audio).to_be_visible()
expect(audio).to_have_attribute("controls", "")
# the src attribute will change based off of machine so do a non None check
assert audio.get_attribute("src") != None
expect(audio).to_have_attribute("src", re.compile(r".*media.*wav"))
13 changes: 5 additions & 8 deletions e2e_playwright/st_expander_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,10 @@ def test_expander_session_state_set(app: Page):
app.get_by_text("Print State Value").click()
wait_for_app_run(app)

text_elements = app.get_by_test_id("stText")

text_elements = app.locator("[data-testid='stText']")
expect(text_elements).to_have_count(2)
text_elements = text_elements.all_inner_texts()
texts = [text.strip() for text in text_elements]

expected = [
"0.0",
"0.0",
]
assert texts == expected

expect(text_elements.nth(0)).to_have_text("0.0", use_inner_text=True)
expect(text_elements.nth(1)).to_have_text("0.0", use_inner_text=True)
27 changes: 16 additions & 11 deletions e2e_playwright/st_graphviz_chart_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,24 @@

from playwright.sync_api import Page, expect

from e2e_playwright.conftest import ImageCompareFunction, wait_for_app_run
from e2e_playwright.conftest import ImageCompareFunction, wait_for_app_run, wait_until


def get_first_graph_svg(app: Page):
return app.locator(".stGraphVizChart > svg").nth(0)
return app.get_by_test_id("stGraphVizChart").nth(0).locator("svg")


def click_fullscreen(app: Page):
app.locator('[data-testid="StyledFullScreenButton"]').nth(0).click()
app.get_by_test_id("StyledFullScreenButton").nth(0).click()
# Wait for the animation to finish
app.wait_for_timeout(1000)


def test_initial_setup(app: Page):
"""Initial setup: ensure charts are loaded."""
expect(app.locator(".stGraphVizChart > svg > g > title")).to_have_count(6)
expect(
app.get_by_test_id("stGraphVizChart").locator("svg > g > title")
).to_have_count(6)


def test_shows_left_and_right_graph(app: Page):
Expand All @@ -55,7 +57,7 @@ def test_first_graph_fullscreen(app: Page, assert_snapshot: ImageCompareFunction
"""Test if the first graph shows in fullscreen."""

# Hover over the parent div
app.locator(".stGraphVizChart").nth(0).hover()
app.get_by_test_id("stGraphVizChart").nth(0).hover()

# Enter fullscreen
click_fullscreen(app)
Expand All @@ -65,9 +67,12 @@ def test_first_graph_fullscreen(app: Page, assert_snapshot: ImageCompareFunction
expect(first_graph_svg).not_to_have_attribute("width", "79pt")
expect(first_graph_svg).not_to_have_attribute("height", "116pt")

svg_dimensions = first_graph_svg.bounding_box()
assert svg_dimensions["width"] == 1256
assert svg_dimensions["height"] == 662
def check_dimensions():
svg_dimensions = first_graph_svg.bounding_box()
return svg_dimensions["width"] == 1256 and svg_dimensions["height"] == 662

wait_until(app, check_dimensions)

assert_snapshot(first_graph_svg, name="graphviz_fullscreen")


Expand All @@ -77,7 +82,7 @@ def test_first_graph_after_exit_fullscreen(
"""Test if the first graph has correct size after exiting fullscreen."""

# Hover over the parent div
app.locator(".stGraphVizChart").nth(0).hover()
app.get_by_test_id("stGraphVizChart").nth(0).hover()

# Enter and exit fullscreen
click_fullscreen(app)
Expand All @@ -104,7 +109,7 @@ def test_renders_with_specified_engines(
expect(app.get_by_test_id("stMarkdown").nth(0)).to_have_text(engine)

assert_snapshot(
app.locator(".stGraphVizChart > svg").nth(2),
app.get_by_test_id("stGraphVizChart").nth(2).locator("svg"),
name=f"st_graphviz_chart_engine-{engine}",
)

Expand All @@ -116,6 +121,6 @@ def test_dot_string(app: Page, assert_snapshot: ImageCompareFunction):
expect(title).to_have_text("Dot")

assert_snapshot(
app.locator(".stGraphVizChart > svg").nth(5),
app.get_by_test_id("stGraphVizChart").nth(5).locator("svg"),
name="st_graphviz_chart_dot_string",
)
Loading

0 comments on commit 516e67c

Please sign in to comment.