Skip to content

Commit

Permalink
functions
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda committed Jun 4, 2024
1 parent 84ab387 commit 6c3c7c3
Show file tree
Hide file tree
Showing 4 changed files with 435 additions and 24 deletions.
16 changes: 16 additions & 0 deletions src/vanna/advanced/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod


class VannaAdvanced(ABC):
def __init__(self, vanna_kb: str, vanna_api_key: str, config=None):
self.vanna_kb = vanna_kb
self.vanna_api_key = vanna_api_key
self.config = config

@abstractmethod
def get_function(self, question: str, additional_data: dict = {}) -> dict:
pass

@abstractmethod
def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
pass
131 changes: 124 additions & 7 deletions src/vanna/flask/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import logging
import os
import sys
import uuid
from abc import ABC, abstractmethod
from functools import wraps

import flask
import requests
from flask import Flask, Response, jsonify, request
from flask import Flask, Response, jsonify, request, send_from_directory
from flask_sock import Sock

from .assets import css_content, html_content, js_content
Expand Down Expand Up @@ -151,7 +152,10 @@ def __init__(self, vn, cache: Cache = MemoryCache(),
auto_fix_sql=True,
ask_results_correct=True,
followup_questions=True,
summarization=True
summarization=True,
function_generation=True,
index_html_path=None,
assets_folder=None,
):
"""
Expose a Flask app that can be used to interact with a Vanna instance.
Expand All @@ -176,6 +180,8 @@ def __init__(self, vn, cache: Cache = MemoryCache(),
ask_results_correct: Whether to ask the user if the results are correct. Defaults to True.
followup_questions: Whether to show followup questions. Defaults to True.
summarization: Whether to show summarization. Defaults to True.
index_html_path: Path to the index.html. Defaults to None, which will use the default index.html
assets_folder: The location where you'd like to serve the static assets from. Defaults to None, which will use hardcoded Python variables.
Returns:
None
Expand All @@ -202,6 +208,9 @@ def __init__(self, vn, cache: Cache = MemoryCache(),
self.ask_results_correct = ask_results_correct
self.followup_questions = followup_questions
self.summarization = summarization
self.function_generation = function_generation and hasattr(vn, "get_function")
self.index_html_path = index_html_path
self.assets_folder = assets_folder

log = logging.getLogger("werkzeug")
log.setLevel(logging.ERROR)
Expand Down Expand Up @@ -247,6 +256,7 @@ def get_config(user: any):
"ask_results_correct": self.ask_results_correct,
"followup_questions": self.followup_questions,
"summarization": self.summarization,
"function_generation": self.function_generation,
}

config = self.auth.override_config_for_user(user, config)
Expand Down Expand Up @@ -345,6 +355,56 @@ def generate_sql(user: any):
}
)

@self.flask_app.route("/api/v0/get_function", methods=["GET"])
@self.requires_auth
def get_function(user: any):
question = flask.request.args.get("question")

if question is None:
return jsonify({"type": "error", "error": "No question provided"})

if not hasattr(vn, "get_function"):
return jsonify({"type": "error", "error": "This setup does not support function generation."})

id = self.cache.generate_id(question=question)
function = vn.get_function(question=question)

if function is None:
return jsonify({"type": "error", "error": "No function found"})

if 'instantiated_sql' not in function:
self.vn.log(f"No instantiated SQL found for {question} in {function}")
return jsonify({"type": "error", "error": "No instantiated SQL found"})

self.cache.set(id=id, field="question", value=question)
self.cache.set(id=id, field="sql", value=function['instantiated_sql'])

if 'instantiated_post_processing_code' in function and function['instantiated_post_processing_code'] is not None and len(function['instantiated_post_processing_code']) > 0:
self.cache.set(id=id, field="plotly_code", value=function['instantiated_post_processing_code'])

return jsonify(
{
"type": "function",
"id": id,
"function": function,
}
)

@self.flask_app.route("/api/v0/get_all_functions", methods=["GET"])
@self.requires_auth
def get_all_functions(user: any):
if not hasattr(vn, "get_all_functions"):
return jsonify({"type": "error", "error": "This setup does not support function generation."})

functions = vn.get_all_functions()

return jsonify(
{
"type": "functions",
"functions": functions,
}
)

@self.flask_app.route("/api/v0/run_sql", methods=["GET"])
@self.requires_auth
@self.requires_cache(["sql"])
Expand Down Expand Up @@ -438,11 +498,18 @@ def generate_plotly_figure(user: any, id: str, df, question, sql):
question = f"{question}. When generating the chart, use these special instructions: {chart_instructions}"

try:
code = vn.generate_plotly_code(
question=question,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
)
# If chart_instructions is not set then attempt to retrieve the code from the cache
if chart_instructions is None or len(chart_instructions) == 0:
code = self.cache.get(id=id, field="plotly_code")

if code is None:
code = vn.generate_plotly_code(
question=question,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
)
self.cache.set(id=id, field="plotly_code", value=code)

fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
fig_json = fig.to_json()

Expand Down Expand Up @@ -518,6 +585,49 @@ def add_training_data(user: any):
print("TRAINING ERROR", e)
return jsonify({"type": "error", "error": str(e)})

@self.flask_app.route("/api/v0/create_function", methods=["POST"])
@self.requires_auth
def create_function(user: any):
question = flask.request.json.get("question")
sql = flask.request.json.get("sql")
id = flask.request.json.get("id")

plotly_code = self.cache.get(id=id, field="plotly_code")

if plotly_code is None:
plotly_code = ""

function_data = self.vn.create_function(question=question, sql=sql, plotly_code=plotly_code)

return jsonify(
{
"type": "function_template",
"id": id,
"function_template": function_data,
}
)

@self.flask_app.route("/api/v0/update_function", methods=["POST"])
@self.requires_auth
def update_function(user: any):
old_function_name = flask.request.json.get("old_function_name")
updated_function = flask.request.json.get("updated_function")

print("old_function_name", old_function_name)
print("updated_function", updated_function)

updated = vn.update_function(old_function_name=old_function_name, updated_function=updated_function)

return jsonify({"success": updated})

@self.flask_app.route("/api/v0/delete_function", methods=["POST"])
@self.requires_auth
def delete_function(user: any):
function_name = flask.request.json.get("function_name")

return jsonify({"success": vn.delete_function(function_name=function_name)})


@self.flask_app.route("/api/v0/generate_followup_questions", methods=["GET"])
@self.requires_auth
@self.requires_cache(["df", "question", "sql"])
Expand Down Expand Up @@ -616,6 +726,9 @@ def catch_all(catch_all):

@self.flask_app.route("/assets/<path:filename>")
def proxy_assets(filename):
if self.assets_folder:
return send_from_directory(self.assets_folder, filename)

if ".css" in filename:
return Response(css_content, mimetype="text/css")

Expand Down Expand Up @@ -663,6 +776,10 @@ def sock_log(ws):
@self.flask_app.route("/", defaults={"path": ""})
@self.flask_app.route("/<path:path>")
def hello(path: str):
if self.index_html_path:
directory = os.path.dirname(self.index_html_path)
filename = os.path.basename(self.index_html_path)
return send_from_directory(directory=directory, path=filename)
return html_content

def run(self, *args, **kwargs):
Expand Down
52 changes: 36 additions & 16 deletions src/vanna/flask/assets.py

Large diffs are not rendered by default.

Loading

0 comments on commit 6c3c7c3

Please sign in to comment.