Skip to content

Commit

Permalink
simplify exp ut (microsoft#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxu0307 authored Jan 31, 2024
1 parent 9ab8890 commit ec23f76
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ experience_text: 'User Query: show top 3 data in ./demo_data.csv
read it.
2. Update the file path in the code when the user provides a new one.'
raw_experience_path: D:\python_project\TaskWeaver\tests\unit_tests\data\experience\raw_exp_test-exp-1.yaml
raw_experience_path: D:\TaskWeaver\tests\unit_tests\data\experience\raw_exp_test-exp-1.yaml
embedding_model: all-mpnet-base-v2
embedding:
- 0.07244912534952164
Expand Down
69 changes: 0 additions & 69 deletions tests/unit_tests/test_experience.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import pytest
import yaml
from injector import Injector

from taskweaver.config.config_mgt import AppConfigSource
Expand All @@ -11,72 +10,10 @@
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_experience_generation():
app_injector = Injector([LoggingModule])
app_config = AppConfigSource(
# need to configure llm related config to run this test
# please refer to the https://microsoft.github.io/TaskWeaver/docs/llms
config_file_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"..",
"project/taskweaver_config.json",
),
config={
"llm.embedding_api_type": "sentence_transformers",
"llm.embedding_model": "all-mpnet-base-v2",
"experience.experience_dir": os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"data/experience",
),
},
)
app_injector.binder.bind(AppConfigSource, to=app_config)
experience_manager = app_injector.create_object(ExperienceGenerator)

experience_manager.refresh(target_role="Planner")
experience_manager.load_experience(target_role="Planner")

exp_files = os.listdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/experience"))
assert len(exp_files) == 2
assert "Planner_exp_test-exp-1.yaml" in exp_files

assert len(experience_manager.experience_list) == 1
exp = experience_manager.experience_list[0]
assert len(exp.experience_text) > 0
assert exp.exp_id == "test-exp-1"
assert len(exp.embedding) == 768
assert exp.raw_experience_path == os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"data",
"experience",
"raw_exp_test-exp-1.yaml",
)
assert exp.embedding_model == "all-mpnet-base-v2"

with open(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/experience/Planner_exp_test-exp-1.yaml"),
) as f:
exp = yaml.safe_load(f)
assert "experience_text" in exp
assert exp["exp_id"] == "test-exp-1"
assert len(exp["embedding"]) == 768
assert exp["raw_experience_path"] == os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"data",
"experience",
"raw_exp_test-exp-1.yaml",
)
assert exp["embedding_model"] == "all-mpnet-base-v2"


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_experience_retrieval():
app_injector = Injector([LoggingModule])
app_config = AppConfigSource(
# need to configure llm related config to run this test
# please refer to the https://microsoft.github.io/TaskWeaver/docs/llms
config_file_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
Expand Down Expand Up @@ -107,12 +44,6 @@ def test_experience_retrieval():
assert len(exp.experience_text) > 0
assert exp.exp_id == "test-exp-1"
assert len(exp.embedding) == 768
assert exp.raw_experience_path == os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"data",
"experience",
"raw_exp_test-exp-1.yaml",
)
assert exp.embedding_model == "all-mpnet-base-v2"

experiences = experience_manager.retrieve_experience(user_query=user_query)
Expand Down

0 comments on commit ec23f76

Please sign in to comment.