Skip to content

Commit

Permalink
functionrag
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda committed Jun 7, 2024
1 parent 6c3c7c3 commit 8c66292
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 122 deletions.
16 changes: 13 additions & 3 deletions src/vanna/advanced/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@


class VannaAdvanced(ABC):
def __init__(self, vanna_kb: str, vanna_api_key: str, config=None):
self.vanna_kb = vanna_kb
self.vanna_api_key = vanna_api_key
def __init__(self, config=None):
self.config = config

@abstractmethod
Expand All @@ -14,3 +12,15 @@ def get_function(self, question: str, additional_data: dict = {}) -> dict:
@abstractmethod
def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
pass

@abstractmethod
def update_function(self, old_function_name: str, updated_function: dict) -> bool:
pass

@abstractmethod
def delete_function(self, function_name: str) -> bool:
pass

@abstractmethod
def get_all_functions(self) -> list:
pass
2 changes: 1 addition & 1 deletion src/vanna/flask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,4 +809,4 @@ def run(self, *args, **kwargs):
print("Your app is running at:")
print("http://localhost:8084")

self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug)
self.flask_app.run(host="0.0.0.0", port=8084, debug=self.debug, use_reloader=False)
30 changes: 15 additions & 15 deletions src/vanna/flask/assets.py

Large diffs are not rendered by default.

126 changes: 23 additions & 103 deletions src/vanna/vannadb/vannadb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ def __init__(self, vanna_model: str, vanna_api_key: str, config=None):
else config["endpoint"]
)
self.related_training_data = {}
self._graphql_endpoint = "http://localhost:8080/query"
self._graphql_endpoint = "https://functionrag.com/query"
self._graphql_headers = {
"Content-Type": "application/json",
"API-KEY": self._api_key,
"NAMESPACE": self._model,
}

def _rpc_call(self, method, params):
if method != "list_orgs":
Expand Down Expand Up @@ -62,43 +67,8 @@ def _dataclass_to_dict(self, obj):
return dataclasses.asdict(obj)

def get_all_functions(self) -> list:
# return [
# {
# "function_name": "calculate_average",
# "description": "Calculates the average value for a specified column in a table.",
# "arguments": [
# {
# "name": "column",
# "general_type": "STRING",
# "is_user_editable": True,
# "available_values": None
# },
# {
# "name": "table",
# "general_type": "STRING",
# "is_user_editable": True,
# "available_values": None
# }
# ],
# "sql_template": "SELECT AVG({column}) FROM {table};"
# },
# {
# "function_name": "get_top_customers_by_sales",
# "description": "Get the top customers by sales",
# "arguments": [
# {
# "name": "number_of_customers",
# "general_type": "NUMERIC",
# "is_user_editable": True,
# "available_values": None
# }
# ],
# "sql_template": "SELECT * FROM customers ORDER BY sales DESC LIMIT {number_of_customers};"
# }
# ]

query = """
{
{
get_all_sql_functions {
function_name
description
Expand All @@ -112,10 +82,10 @@ def get_all_functions(self) -> list:
}
sql_template
}
}
}
"""

response = requests.post(self._graphql_endpoint, json={'query': query})
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query})
response_json = response.json()
if response.status_code == 200 and 'data' in response_json and 'get_all_sql_functions' in response_json['data']:
self.log(response_json['data']['get_all_sql_functions'])
Expand All @@ -128,14 +98,10 @@ def get_all_functions(self) -> list:
raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")




def get_function(self, question: str, additional_data: dict = {}) -> dict:
# return {'function_name': 'get_artist_with_most_albums', 'description': 'What is the name of the artist that has the most albums?', 'arguments': [], 'sql_template': 'SELECT a.Name, COUNT(al.AlbumId) AS AlbumCount FROM Artist a JOIN Album al ON a.ArtistId = al.ArtistId GROUP BY a.Name ORDER BY AlbumCount DESC LIMIT 1;', 'instantiated_sql': 'SELECT a.Name, COUNT(al.AlbumId) AS AlbumCount FROM Artist a JOIN Album al ON a.ArtistId = al.ArtistId GROUP BY a.Name ORDER BY AlbumCount DESC LIMIT 1;'}

query = """
query GetFunction($question: String!) {
get_function(question: $question) {
query GetFunction($question: String!, $staticFunctionArguments: [StaticFunctionArgument]) {
get_and_instantiate_function(question: $question, static_function_arguments: $staticFunctionArguments) {
... on SQLFunction {
function_name
description
Expand All @@ -155,53 +121,24 @@ def get_function(self, question: str, additional_data: dict = {}) -> dict:
}
}
"""
variables = {"question": question}
response = requests.post(self._graphql_endpoint, json={'query': query, 'variables': variables})
static_function_arguments = [{"name": key, "value": str(value)} for key, value in additional_data.items()]
variables = {"question": question, "staticFunctionArguments": static_function_arguments}
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
response_json = response.json()
if response.status_code == 200 and 'data' in response_json and 'get_function' in response_json['data']:
self.log(response_json['data']['get_function'])
resp = response_json['data']['get_function']
if response.status_code == 200 and 'data' in response_json and 'get_and_instantiate_function' in response_json['data']:
self.log(response_json['data']['get_and_instantiate_function'])
resp = response_json['data']['get_and_instantiate_function']

print(resp)

return resp
else:
raise Exception(f"Query failed to run by returning code of {response.status_code}. {response.text}")


# params = [Question(question=question)]

# # For now this is just a mock
# mock_function_return = {
# "function_name": "example_function",
# "arguments": [
# {
# "name": "arg1",
# "general_type": "String",
# "is_user_editable": True,
# "instantiated_value": "value1",
# "available_values": ["value1", "value2", "value3"]
# },
# {
# "name": "arg2",
# "general_type": "Integer",
# "is_user_editable": False,
# "instantiated_value": "10",
# "available_values": []
# }
# ],
# "sql_template": "SELECT * FROM table WHERE column1 = {{arg1}} AND column2 = {{arg2}};",
# "instantiated_sql": "SELECT COUNT(*) FROM Artist"
# }

return mock_function_return

def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -> dict:
# return {'function_name': 'get_top_artists_by_sales', 'description': 'Who are the top {limit} artists by sales?', 'arguments': [{'name': 'limit', 'description': 'The number of top artists to retrieve based on sales', 'general_type': 'STRING', 'is_user_editable': True}], 'sql_template': 'SELECT a.Name, SUM(i.UnitPrice * i.Quantity) AS TotalSales FROM Artist a JOIN Album al ON a.ArtistId = al.ArtistId JOIN Track t ON al.AlbumId = t.AlbumId JOIN InvoiceLine i ON t.TrackId = i.TrackId GROUP BY a.Name ORDER BY TotalSales DESC LIMIT {limit};'}

query = """
mutation CreateFunction($question: String!, $sql: String!, $plotly_code: String!) {
create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
generate_and_create_sql_function(question: $question, sql: $sql, post_processing_code: $plotly_code) {
function_name
description
arguments {
Expand All @@ -216,10 +153,10 @@ def create_function(self, question: str, sql: str, plotly_code: str, **kwargs) -
}
"""
variables = {"question": question, "sql": sql, "plotly_code": plotly_code}
response = requests.post(self._graphql_endpoint, json={'query': query, 'variables': variables})
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': query, 'variables': variables})
response_json = response.json()
if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'create_sql_function' in response_json['data']:
resp = response_json['data']['create_sql_function']
if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'generate_and_create_sql_function' in response_json['data']:
resp = response_json['data']['generate_and_create_sql_function']

print(resp)

Expand Down Expand Up @@ -249,23 +186,6 @@ def update_function(self, old_function_name: str, updated_function: dict) -> boo
}
"""

# input SQLFunctionUpdate {
# old_function_name: String!
# function_name: String!
# description: String!
# arguments: [ArgumentUpdate]!
# sql_template: String!
# post_processing_code_template: String!
# }

# input ArgumentUpdate {
# name: String!
# general_type: GeneralType!
# description: String!
# is_user_editable: Boolean!
# available_values: [String]
# }

SQLFunctionUpdate = {
'function_name', 'description', 'arguments', 'sql_template', 'post_processing_code_template'
}
Expand Down Expand Up @@ -296,7 +216,7 @@ def validate_arguments(args):

print("variables", variables)

response = requests.post(self._graphql_endpoint, json={'query': mutation, 'variables': variables})
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
response_json = response.json()
if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'update_sql_function' in response_json['data']:
return response_json['data']['update_sql_function']
Expand All @@ -310,7 +230,7 @@ def delete_function(self, function_name: str) -> bool:
}
"""
variables = {"function_name": function_name}
response = requests.post(self._graphql_endpoint, json={'query': mutation, 'variables': variables})
response = requests.post(self._graphql_endpoint, headers=self._graphql_headers, json={'query': mutation, 'variables': variables})
response_json = response.json()
if response.status_code == 200 and 'data' in response_json and response_json['data'] is not None and 'delete_sql_function' in response_json['data']:
return response_json['data']['delete_sql_function']
Expand Down

0 comments on commit 8c66292

Please sign in to comment.