Add input summarization. (#404)
Signed-off-by: wxywb <[email protected]>
wxywb authored Jun 2, 2023
1 parent fd7e303 commit 18bfb1d
Showing 6 changed files with 52 additions and 12 deletions.
17 changes: 17 additions & 0 deletions gptcache/adapter/adapter.py
@@ -20,6 +20,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs):
context = kwargs.pop("cache_context", {})
embedding_data = None
    # allow the request to be retried against ChatGPT when the cached result is negative

if 0 < temperature < 2:
cache_skip_options = [True, False]
prob_cache_skip = [0, 1]
@@ -53,6 +54,9 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs):
pre_store_data = pre_embedding_res
pre_embedding_data = pre_embedding_res

if chat_cache.config.input_summary_len is not None:
pre_embedding_data = summarize_input(pre_embedding_data, chat_cache.config.input_summary_len)

if cache_enable:
embedding_data = time_cal(
chat_cache.embedding_func,
@@ -213,3 +217,16 @@ def update_cache_func(handled_llm_data, question=None):
except Exception as e: # pylint: disable=W0703
gptcache_log.warning("failed to save the data to cache, error: %s", e)
return llm_data


input_summarizer = None

def summarize_input(text, text_length):
# pylint: disable=import-outside-toplevel
from gptcache.processor.context.summarization_context import SummarizationContextProcess
global input_summarizer
if input_summarizer is None:
input_summarizer = SummarizationContextProcess()
summarization = input_summarizer.summarize_to_sentence([text], text_length)
return summarization
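
For reference, here is how the new helper behaves when called directly; this mirrors the unit test added below, and the sample text and target length are illustrative. Note that the first call lazily constructs a SummarizationContextProcess, which downloads the default facebook/bart-large-cnn checkpoint.

    from gptcache.adapter.adapter import summarize_input

    # Condense a long prompt before it is embedded and looked up in the cache.
    long_prompt = "A large language model (LLM) is a language model consisting of ..."
    short_prompt = summarize_input(long_prompt, 40)
    print(short_prompt)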

4 changes: 4 additions & 0 deletions gptcache/config.py
@@ -19,6 +19,8 @@ class Config:
:type auto_flush: int
:param enable_token_counter: enable token counter, defaults to True
:type enable_token_counter: bool
:param input_summary_len: optional, summarize input to specified length.
:type input_summary_len: Optional[int]
Example:
.. code-block:: python
@@ -36,6 +38,7 @@ def __init__(
template: Optional[str] = None,
auto_flush: int = 20,
enable_token_counter: bool = True,
input_summary_len: Optional[int] = None
):
if similarity_threshold < 0 or similarity_threshold > 1:
raise CacheError(
@@ -47,3 +50,4 @@ def __init__(
self.template = template
self.auto_flush = auto_flush
self.enable_token_counter = enable_token_counter
self.input_summary_len = input_summary_len
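
With the new option in place, input summarization is enabled entirely through the config. A minimal sketch, assuming the usual cache.init flow (get_prompt and the default data manager are illustrative choices, not part of this change):

    from gptcache import cache, Config
    from gptcache.processor.pre import get_prompt

    # Prompts are summarized down to ~512 tokens before embedding and lookup.
    cache.init(pre_embedding_func=get_prompt, config=Config(input_summary_len=512))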
4 changes: 2 additions & 2 deletions gptcache/processor/context/__init__.py
@@ -20,9 +20,9 @@
]


-def SummarizationContextProcess(summarizer=None, tokenizer=None, target_length=512):
+def SummarizationContextProcess(model_name=None, tokenizer=None, target_length=512):
    return summarization.SummarizationContextProcess(
-       summarizer, tokenizer, target_length
+       model_name, tokenizer, target_length
    )
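
Because the factory now takes a model name rather than a pre-built pipeline, swapping in a different summarizer no longer requires importing transformers at the call site. For example (the checkpoint name is the one used in the updated test; any Hugging Face summarization model should work, and is downloaded on first use):

    from gptcache.processor.context import SummarizationContextProcess

    # Build the processor around a smaller summarization checkpoint.
    context_process = SummarizationContextProcess(model_name="ainize/bart-base-cnn")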


29 changes: 22 additions & 7 deletions gptcache/processor/context/summarization_context.py
@@ -9,6 +9,22 @@

import transformers # pylint: disable=C0413

def summarize_to_length(summarizer, text, target_len, max_len=1024):
tokenizer = summarizer.tokenizer
def token_length(text):
return len(tokenizer.encode(text))
segment_len = max_len - 100
summary_result = text
while token_length(text) > target_len:
tokens = tokenizer.encode(text)
segments = [tokens[i:i+segment_len] for i in range(0, len(tokens), segment_len-1)]
summary_result = ""
for segment in segments:
len_seg = int(len(segment)/4)
summary = summarizer(tokenizer.decode(segment), min_length=max(len_seg-10, 1), max_length=len_seg)
summary_result += summary[0]["summary_text"]
text = summary_result
return summary_result

class SummarizationContextProcess(ContextProcess):
"""A context processor for summarizing large amounts of text data using a summarizer model.
@@ -29,18 +45,17 @@ class SummarizationContextProcess(ContextProcess):
context_process = SummarizationContextProcess()
cache.init(pre_embedding_func=context_process.pre_process)
"""
-    def __init__(self, summarizer=transformers.pipeline("summarization", model="facebook/bart-large-cnn"),
+    def __init__(self, model_name="facebook/bart-large-cnn",
                 tokenizer=None, target_length=512):
-        if not summarizer:
-            summarizer = transformers.pipeline("summarization", model="facebook/bart-large-cnn")
+        summarizer = transformers.pipeline(task="summarization", model=model_name)
self.summarizer = summarizer
self.target_length = target_length
if tokenizer is None:
tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")
self.tokenizer = tokenizer
self.content = ""

-    def summarize_to_sentence(self, summarizer, sentences, target_size=1000):
+    def summarize_to_sentence(self, sentences, target_size=1000):
lengths = []
for sentence in sentences:
lengths.append(len(sentence))
@@ -49,8 +64,8 @@ def summarize_to_sentence(self, summarizer, sentences, target_size=1000):
target_sentences = []
for sent, target_len in zip(sentences, target_lengths):
if len(self.tokenizer.tokenize(sent)) > target_len:
-                response = summarizer(sent, max_length=target_len, min_length=1, do_sample=False)
-                target_sentence = response[0]["summary_text"]
+                response = summarize_to_length(self.summarizer, sent, target_len, self.tokenizer.model_max_length)
+                target_sentence = response
else:
target_sentence = sent
target_sentences.append(target_sentence)
@@ -71,7 +86,7 @@ def serialize_content(content):
for message in content:
ret += "[#RS]{}[#RE][#CS]{}[#CE]".format(message["role"], message["content"])
return ret
-        result = self.summarize_to_sentence(self.summarizer, [message["content"] for message in self.content], self.target_length)
+        result = self.summarize_to_sentence([message["content"] for message in self.content], self.target_length)
save_content = serialize_content(self.content)
embedding_content = result
return save_content, embedding_content
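
One detail worth noting in summarize_to_length: it walks the token stream with stride segment_len - 1, so consecutive chunks overlap by one token, and it keeps re-summarizing its own output until the token count fits target_len. A toy illustration of the chunk boundaries (values assumed: max_len=1024, hence segment_len=924):

    tokens = list(range(2000))           # stand-in for tokenizer.encode(text)
    segment_len = 1024 - 100             # 924, as computed in summarize_to_length
    segments = [tokens[i:i + segment_len] for i in range(0, len(tokens), segment_len - 1)]
    print([seg[0] for seg in segments])  # chunk start offsets: [0, 923, 1846]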
6 changes: 5 additions & 1 deletion tests/unit_tests/adapter/test_adapter.py
@@ -4,7 +4,7 @@

from gptcache.manager import get_data_manager, manager_factory
from gptcache.utils.error import NotInitError
-from gptcache.adapter.adapter import adapt
+from gptcache.adapter.adapter import adapt, summarize_input
from gptcache.adapter.api import put, get
from gptcache.processor.pre import get_prompt
from gptcache.processor.post import first, nop
@@ -156,6 +156,10 @@ def test_cache_temperature():
answers = get(prompt=prompt)
assert len(answers) == 2

def test_input_summarization():
text = "A large language model (LLM) is a language model consisting of a neural network with many parameters (typically billions of weights or more), trained on large quantities of unlabeled text using self-supervised learning or semi-supervised learning. LLMs emerged around 2018 and perform well at a wide variety of tasks. This has shifted the focus of natural language processing research away from the previous paradigm of training specialized supervised models for specific tasks."
summary = summarize_input(text, 40)
assert len(summary.split()) < 40

if __name__ == "__main__":
test_cache_temperature()
4 changes: 2 additions & 2 deletions tests/unit_tests/processor/test_summarize_context.py
@@ -6,8 +6,8 @@

@pytest.mark.tags("L2")
def test_summarization_context_process():
-    summarizer = pipeline("summarization", model="ainize/bart-base-cnn")
-    context_process = _get_pre_context_function("summarization", kws={"summarizer": summarizer, "target_length": 512})
+    #summarizer = pipeline(task="summarization", model="ainize/bart-base-cnn")
+    context_process = _get_pre_context_function("summarization", kws={"model_name": "facebook/bart-large-cnn", "target_length": 512})
chat = []
chat.append(
{
