Skip to content

Commit

Permalink
nltk
Browse files Browse the repository at this point in the history
  • Loading branch information
Wenshansilvia committed Oct 2, 2024
1 parent 3b2a2e5 commit d15bef9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
6 changes: 4 additions & 2 deletions rageval/utils/check_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from .prompt import DOC_TO_SENTENCES_PROMPT

logger = logging.getLogger(__name__)
if not Downloader().is_installed('punkt'):
nltk.download('punkt')
#if not Downloader().is_installed('punkt'):
# nltk.download('punkt')
if not Downloader().is_installed('punkt_tab'):
nltk.download('punkt_tab')


def text_to_sents(text: str, model_name="nltk") -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rageval.metrics import AnswerNLICorrectness


@pytest.fixture(scope='module')
#@pytest.fixture(scope='module')
def sample():
test_case = {
"answers": [
Expand All @@ -28,7 +28,7 @@ def sample():
return test_case


@pytest.fixture(scope='module')
#@pytest.fixture(scope='module')
def sample_with_decompose():
test_case = {
"answers": [
Expand All @@ -47,19 +47,19 @@ def sample_with_decompose():
return test_case


@pytest.fixture(scope='module')
#@pytest.fixture(scope='module')
def testset(sample):
ds = Dataset.from_dict(sample)
return ds


@pytest.fixture(scope='module')
#@pytest.fixture(scope='module')
def testset_with_decompose(sample_with_decompose):
ds = Dataset.from_dict(sample_with_decompose)
return ds


@pytest.mark.slow
#@pytest.mark.slow
def test_case_on_answer_claim_recall_metric(testset):
nli_model = NLIModel(
'text2text-generation',
Expand All @@ -72,7 +72,7 @@ def test_case_on_answer_claim_recall_metric(testset):
assert score == 0 or score == 1


@pytest.mark.slow
#@pytest.mark.slow
def test_case_on_answer_claim_recall_metric_with_decompose(testset_with_decompose):
nli_model = NLIModel(
'text2text-generation',
Expand All @@ -83,3 +83,5 @@ def test_case_on_answer_claim_recall_metric_with_decompose(testset_with_decompos
assert metric.mtype == 'AnswerCorrectness'
score, results = metric.compute(testset_with_decompose['answers'], testset_with_decompose['gt_answers'], 1)
assert score == 0 or score == 1

test_case_on_answer_claim_recall_metric_with_decompose(testset_with_decompose(sample_with_decompose()))

0 comments on commit d15bef9

Please sign in to comment.