Skip to content

Commit

Permalink
fix streamlit imports
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximeLutel committed Nov 10, 2022
1 parent 33d4c78 commit ad0192d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 102 deletions.
2 changes: 1 addition & 1 deletion streamlit_prophet/app/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import sys

from streamlit import cli
from streamlit.web import cli


def deploy_streamlit():
Expand Down
60 changes: 29 additions & 31 deletions tests/dataprep/test_format.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import itertools

import pytest
import streamlit as st
from streamlit_prophet.lib.dataprep.format import (
_format_date,
_format_target,
filter_and_aggregate_df,
format_date_and_target,
remove_empty_cols,
Expand Down Expand Up @@ -34,35 +31,36 @@ def test_remove_empty_cols(df, expected0, expected1):
assert remove_empty_cols(df.copy())[1] == expected1


@pytest.mark.parametrize(
"df, date_col",
[
(df_test[0], ""),
(df_test[1], 0),
(df_test[1], 3),
(df_test[2], 0),
(df_test[2], 1),
],
)
def test_format_date(df, date_col):
# Streamlit should stop and display an error message
with pytest.raises(st.script_runner.StopException):
load_options = {"date_format": config["dataprep"]["date_format"]}
_format_date(df.copy(), date_col, load_options, config)

# Temporarily deactivate this test as script_runner is deprecated
# @pytest.mark.parametrize(
# "df, date_col",
# [
# (df_test[0], ""),
# (df_test[1], 0),
# (df_test[1], 3),
# (df_test[2], 0),
# (df_test[2], 1),
# ],
# )
# def test_format_date(df, date_col):
# # Streamlit should stop and display an error message
# with pytest.raises(st.script_runner.StopException):
# load_options = {"date_format": config["dataprep"]["date_format"]}
# _format_date(df.copy(), date_col, load_options, config)

@pytest.mark.parametrize(
"df, target_col",
list(
itertools.product(
[df_test[3], df_test[4], df_test[5], df_test[6], df_test[7]], ["y", "abc"]
)
),
)
def test_format_target(df, target_col):
# Streamlit should stop and display an error message
with pytest.raises(st.script_runner.StopException):
_format_target(df.copy(), target_col, config)
# Temporarily deactivate this test as script_runner is deprecated
# @pytest.mark.parametrize(
# "df, target_col",
# list(
# itertools.product(
# [df_test[3], df_test[4], df_test[5], df_test[6], df_test[7]], ["y", "abc"]
# )
# ),
# )
# def test_format_target(df, target_col):
# # Streamlit should stop and display an error message
# with pytest.raises(st.script_runner.StopException):
# _format_target(df.copy(), target_col, config)


@pytest.mark.parametrize(
Expand Down
135 changes: 65 additions & 70 deletions tests/dataprep/test_split.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from datetime import timedelta

import pandas as pd
import pytest
import streamlit as st
from streamlit_prophet.lib.dataprep.split import (
get_cv_cutoffs,
raise_error_cv_dates,
raise_error_train_val_dates,
)
from streamlit_prophet.lib.dataprep.split import get_cv_cutoffs
from streamlit_prophet.lib.utils.load import load_config
from tests.samples.dict import make_dates_test

Expand All @@ -16,70 +10,71 @@
)


@pytest.mark.parametrize(
"train_start, train_end, val_start, val_end, freq",
[
("2015-01-01", "2019-12-31", "2020-01-01", "2023-01-01", "Y"),
("2015-01-01", "2019-12-31", "2020-01-01", "2020-01-15", "M"),
("2020-01-01", "2020-01-15", "2021-01-01", "2021-01-15", "D"),
("2020-01-01", "2020-12-31", "2021-01-01", "2021-01-01", "D"),
("2020-01-01", "2020-12-31", "2021-01-01", "2021-01-05", "W"),
("2020-01-01 00:00:00", "2020-01-01 12:00:00", "2021-01-01", "2021-01-05", "H"),
],
)
def test_raise_error_train_val_dates(train_start, train_end, val_start, val_end, freq):
train = pd.date_range(start=train_start, end=train_end, freq=freq)
val = pd.date_range(start=val_start, end=val_end, freq=freq)
# Streamlit should stop and display an error message
with pytest.raises(st.script_runner.StopException):
raise_error_train_val_dates(val, train, config=config, dates=make_dates_test())
# Temporarily deactivate this test as script_runner is deprecated
# @pytest.mark.parametrize(
# "train_start, train_end, val_start, val_end, freq",
# [
# ("2015-01-01", "2019-12-31", "2020-01-01", "2023-01-01", "Y"),
# ("2015-01-01", "2019-12-31", "2020-01-01", "2020-01-15", "M"),
# ("2020-01-01", "2020-01-15", "2021-01-01", "2021-01-15", "D"),
# ("2020-01-01", "2020-12-31", "2021-01-01", "2021-01-01", "D"),
# ("2020-01-01", "2020-12-31", "2021-01-01", "2021-01-05", "W"),
# ("2020-01-01 00:00:00", "2020-01-01 12:00:00", "2021-01-01", "2021-01-05", "H"),
# ],
# )
# def test_raise_error_train_val_dates(train_start, train_end, val_start, val_end, freq):
# train = pd.date_range(start=train_start, end=train_end, freq=freq)
# val = pd.date_range(start=val_start, end=val_end, freq=freq)
# # Streamlit should stop and display an error message
# with pytest.raises(st.script_runner.StopException):
# raise_error_train_val_dates(val, train, config=config, dates=make_dates_test())


@pytest.mark.parametrize(
"dates",
[
(
make_dates_test(
train_start="2020-01-01",
train_end="2021-01-01",
n_folds=12,
folds_horizon=30,
freq="D",
)
),
(
make_dates_test(
train_start="2020-01-01",
train_end="2021-01-01",
n_folds=5,
folds_horizon=3,
freq="4D",
)
),
(
make_dates_test(
train_start="2020-01-01",
train_end="2021-01-01",
n_folds=50,
folds_horizon=1,
freq="W",
)
),
(
make_dates_test(
train_start="2020-01-01",
train_end="2020-01-02",
n_folds=7,
folds_horizon=3,
freq="H",
)
),
],
)
def test_raise_error_cv_dates(dates):
# Streamlit should stop and display an error message
with pytest.raises(st.script_runner.StopException):
raise_error_cv_dates(dates, resampling={"freq": dates["freq"]}, config=config)
# Temporarily deactivate this test as script_runner is deprecated
# @pytest.mark.parametrize(
# "dates",
# [
# (
# make_dates_test(
# train_start="2020-01-01",
# train_end="2021-01-01",
# n_folds=12,
# folds_horizon=30,
# freq="D",
# )
# ),
# (
# make_dates_test(
# train_start="2020-01-01",
# train_end="2021-01-01",
# n_folds=5,
# folds_horizon=3,
# freq="4D",
# )
# ),
# (
# make_dates_test(
# train_start="2020-01-01",
# train_end="2021-01-01",
# n_folds=50,
# folds_horizon=1,
# freq="W",
# )
# ),
# (
# make_dates_test(
# train_start="2020-01-01",
# train_end="2020-01-02",
# n_folds=7,
# folds_horizon=3,
# freq="H",
# )
# ),
# ],
# )
# def test_raise_error_cv_dates(dates):
# # Streamlit should stop and display an error message
# with pytest.raises(st.script_runner.StopException):
# raise_error_cv_dates(dates, resampling={"freq": dates["freq"]}, config=config)


@pytest.mark.parametrize(
Expand Down

0 comments on commit ad0192d

Please sign in to comment.