diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 0299fe81..e136201d 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -79,6 +79,7 @@ def __init__(self, config=None): self.static_documentation = "" self.dialect = self.config.get("dialect", "SQL") self.language = self.config.get("language", None) + self.max_tokens = self.config.get("max_tokens", 14000) def log(self, message: str, title: str = "Info"): print(message) @@ -559,14 +560,14 @@ def get_sql_prompt( "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. " initial_prompt = self.add_ddl_to_prompt( - initial_prompt, ddl_list, max_tokens=14000 + initial_prompt, ddl_list, max_tokens=self.max_tokens ) if self.static_documentation != "": doc_list.append(self.static_documentation) initial_prompt = self.add_documentation_to_prompt( - initial_prompt, doc_list, max_tokens=14000 + initial_prompt, doc_list, max_tokens=self.max_tokens ) initial_prompt += ( @@ -603,15 +604,15 @@ def get_followup_questions_prompt( initial_prompt = f"The user initially asked the question: '{question}': \n\n" initial_prompt = self.add_ddl_to_prompt( - initial_prompt, ddl_list, max_tokens=14000 + initial_prompt, ddl_list, max_tokens=self.max_tokens ) initial_prompt = self.add_documentation_to_prompt( - initial_prompt, doc_list, max_tokens=14000 + initial_prompt, doc_list, max_tokens=self.max_tokens ) initial_prompt = self.add_sql_to_prompt( - initial_prompt, question_sql_list, max_tokens=14000 + initial_prompt, question_sql_list, max_tokens=self.max_tokens ) message_log = [self.system_message(initial_prompt)]