-
Notifications
You must be signed in to change notification settings - Fork 27
/
app_streamlit.py
165 lines (130 loc) · 4.85 KB
/
app_streamlit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import base64
from collections import OrderedDict
import requests
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
# HTTP headers attached to every POST issued against the search endpoint.
headers = {"Content-Type": "application/json; charset=utf-8"}

# Fallback search endpoint, used when a caller does not supply its own.
DEFAULT_ENDPOINT = "http://0.0.0.0:45678/search"
class text:
    """
    Jina text search
    """

    class process:
        """
        Process query and results
        """

        @staticmethod
        def json(query: str, top_k: int, endpoint: str) -> list:
            """
            POST a text query to the search endpoint and parse the matches.

            Args:
                query: Text to search for.
                top_k: Sent to the server as the ``limit`` request parameter.
                endpoint: URL of the search endpoint.

            Returns:
                A list with one ``OrderedDict`` per retained match, holding the
                keys ``base_score``, ``title``, ``question`` and ``answer``.
                Matches whose ``parent_id`` is not ``None`` are skipped.

            Raises:
                requests.HTTPError: If the endpoint answers with an error status.
                KeyError: If the response lacks one of the expected fields.
            """
            data = {
                "data": [{"text": query}],
                "targetExecutor": "",
                "parameters": {"limit": top_k},
            }
            response = requests.post(endpoint, headers=headers, json=data)
            # Surface HTTP failures directly instead of letting them show up
            # later as a confusing KeyError / JSON decoding error.
            response.raise_for_status()
            content = response.json()["data"]

            results = []
            for doc in content:
                for match in doc["matches"]:
                    # Only matches without a parent are reported.
                    if match["parent_id"] is not None:
                        continue
                    tags = match["tags"]
                    results.append(
                        OrderedDict(
                            {
                                "base_score": match["scores"]["cosine"]["value"],
                                "title": tags["title"],
                                "question": tags["question"],
                                "answer": tags["answer"],
                            }
                        )
                    )
            return results
class image:
    """
    Jina image search
    """

    class encode:
        """
        Encode image to base64 and return JSON string
        """

        @staticmethod
        def img_base64(byte_string):
            """
            Encode raw image bytes as a JSON array holding one PNG data URI.

            Args:
                byte_string: Raw bytes of the image file.

            Returns:
                A string of the form ``["data:image/png;base64,<payload>"]``.
            """
            # base64 output is pure ASCII, so decode it rather than slicing
            # the ``b'...'`` repr of the bytes object.
            payload = base64.b64encode(byte_string).decode("ascii")
            return f'["data:image/png;base64,{payload}"]'

    class process:
        """
        Process query and results
        """

        @staticmethod
        def json(query: str, top_k: int, endpoint: str) -> list:
            """
            POST an encoded image query and collect the URIs of all matches.

            Args:
                query: JSON fragment placed verbatim in the ``data`` field of
                    the request body (e.g. the output of ``img_base64``).
                top_k: Number of results requested from the server.
                endpoint: URL of the search endpoint.

            Returns:
                The ``uri`` of every match of every returned document, in
                response order.

            Raises:
                requests.HTTPError: If the endpoint answers with an error status.
                KeyError: If the response lacks one of the expected fields.
            """
            # ``query`` is already serialized JSON, so the body is assembled
            # as text rather than passed through a JSON encoder again.
            data = (
                '{"top_k":' + str(top_k) + ', "mode": "search", "data":' + query + "}"
            )
            response = requests.post(endpoint, headers=headers, data=data)
            # Fail loudly on HTTP errors instead of on a later missing key.
            response.raise_for_status()
            content = response.json()["search"]["docs"]
            return [match["uri"] for doc in content for match in doc["matches"]]

    class render:
        """
        Render image output
        """

        @staticmethod
        def html(results: list) -> str:
            """
            Render images as a concatenation of HTML ``<img>`` tags.

            Args:
                results: Image URIs, one per tag.

            Returns:
                One ``<img src="...">`` tag per URI, joined without separator;
                an empty string for an empty list.
            """
            # Local import: this method's own name is ``html``.
            from html import escape

            # The URIs come from a server response and end up in markup, so
            # escape them to keep a crafted value from breaking out of the
            # ``src`` attribute.
            return "".join(f'<img src="{escape(str(uri))}">' for uri in results)
class jina:
    """
    Streamlit widgets for querying a Jina search endpoint
    """

    @staticmethod
    def text_search(endpoint=DEFAULT_ENDPOINT, top_k=10, hidden=None):
        """
        Draw a text-search form and show the matches once "Search" is pressed.

        Args:
            endpoint: URL of the search endpoint; shown in an editable input
                unless ``"endpoint"`` is listed in ``hidden``.
            top_k: Upper bound of the "Results" slider; used directly as the
                number of results when ``"top_k"`` is listed in ``hidden``.
            hidden: Names of the controls (``"endpoint"``, ``"top_k"``) that
                should not be rendered. ``None`` hides nothing.

        Returns:
            The Streamlit container holding the widgets.
        """
        # None instead of a shared mutable list as the default value.
        hidden = () if hidden is None else hidden
        container = st.container()
        with container:
            if "endpoint" not in hidden:
                endpoint = st.text_input("Endpoint", endpoint)

            query = st.text_input("Enter query", value='다른 사람의 땅에 나무를 심었는데 누구 소유인가요?')
            print(f'Query: {query}')

            if "top_k" not in hidden:
                # Keep the initial slider value at or above its minimum (1).
                top_k = st.slider("Results", 1, top_k, max(1, int(top_k / 2)))

            button = st.button("Search")

            if button:
                matches = text.process.json(query, top_k, endpoint)
                st.write(matches)
                print()

        return container

    @staticmethod
    def image_search(endpoint=DEFAULT_ENDPOINT, top_k=10, hidden=None):
        """
        Draw an image-search form and render the matches as images.

        Args:
            endpoint: URL of the search endpoint; shown in an editable input
                unless ``"endpoint"`` is listed in ``hidden``.
            top_k: Upper bound of the "Results" slider; used directly as the
                number of results when ``"top_k"`` is listed in ``hidden``.
            hidden: Names of the controls (``"endpoint"``, ``"top_k"``) that
                should not be rendered. ``None`` hides nothing.

        Returns:
            The value returned by ``st.markdown`` when a search was run,
            otherwise the Streamlit container holding the widgets.
        """
        # None instead of a shared mutable list as the default value.
        hidden = () if hidden is None else hidden
        container = st.container()
        with container:
            if "endpoint" not in hidden:
                endpoint = st.text_input("Endpoint", endpoint)

            query = st.file_uploader("Upload file")

            if "top_k" not in hidden:
                # Keep the initial slider value at or above its minimum (1).
                top_k = st.slider("Results", 1, top_k, max(1, int(top_k / 2)))

            button = st.button("Search")

            if button:
                if query is None:
                    # Nothing uploaded yet: tell the user instead of failing
                    # on ``None.read()``.
                    st.warning("Please upload an image before searching.")
                    return container

                # encode to base64 and embed in json
                encoded_query = image.encode.img_base64(query.read())
                # post to REST API and process response
                matches = image.process.json(encoded_query, top_k, endpoint)
                # convert list of matches to html
                output = image.render.html(matches)
                # render html
                return st.markdown(output, unsafe_allow_html=True)

        return container
# Page configuration is issued before any other Streamlit call in this script.
st.set_page_config(page_title="Jina Text Search")

# Search endpoint used by this app in place of DEFAULT_ENDPOINT.
endpoint = "http://0.0.0.0:1234/search"

st.title("LegalQA with SentenceKoBART")
st.markdown("")

# The endpoint input is hidden, so the URL above is always the one queried.
jina.text_search(endpoint=endpoint, hidden=["endpoint"])