Skip to content

Commit

Permalink
Add the 'hit callback' func (#542)
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG authored Sep 23, 2023
1 parent 15b2fd9 commit 140adb3
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
5 changes: 4 additions & 1 deletion gptcache/adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
answers_dict = dict((d[1], d) for d in cache_answers)
if len(cache_answers) != 0:

hit_callback = kwargs.pop("hit_callback", None)
if hit_callback and callable(hit_callback):
factor = max_rank - min_rank
hit_callback([(d[3].question, d[0] / factor if factor else d[0]) for d in cache_answers])
def post_process():
if chat_cache.post_process_messages_func is temperature_softmax:
return_message = chat_cache.post_process_messages_func(
Expand Down
2 changes: 1 addition & 1 deletion gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def import_huggingface_hub():


def import_onnxruntime():
_check_library("onnxruntime")
_check_library("onnxruntime", package="onnxruntime==1.14.1")


def import_faiss():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/adapter/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_gptcache_api():
put("api-hello2", "foo2", cache_obj=inner_cache)
put("api-hello3", "foo3", cache_obj=inner_cache)

messages = get("hello", cache_obj=inner_cache, top_k=3)
messages = get("hello", cache_obj=inner_cache, top_k=3, hit_callback=lambda x: print("hit_callback", x))
assert len(messages) == 3
assert "foo1" in messages
assert "foo2" in messages
Expand Down

0 comments on commit 140adb3

Please sign in to comment.