Skip to content

Commit

Permalink
Support selecting a model by url params (lm-sys#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Mar 30, 2023
1 parent 99157ce commit df8132d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 12 deletions.
2 changes: 1 addition & 1 deletion fastchat/serve/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def main(args):
outputs += sep
index = outputs.index(sep, len(prompt))

outputs = outputs[len(prompt) + 2:index].strip()
outputs = outputs[len(prompt) + 1:index].strip()
print(f"{conv.roles[1]}: {outputs}")
conv.messages[-1][-1] = outputs

Expand Down
51 changes: 40 additions & 11 deletions fastchat/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,37 @@ def get_model_list():
return models


def load_demo(request: gr.Request):
logger.info(f"load demo: {request.client.host}")
# JavaScript snippet passed to gradio's `demo.load(..., _js=...)`. It runs in
# the browser, collects the page's URL query parameters (e.g. `?model=...`),
# and returns them as a plain object, which gradio feeds to `load_demo` via
# the hidden `url_params` JSON component.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    // `const` added: the original assigned an implicit global, which throws
    // under strict mode and pollutes the page's global scope.
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""


def load_demo(url_params, request: gr.Request):
    """Initialize the demo UI on page load ("once" model-list mode).

    Args:
        url_params: dict of URL query parameters returned by the
            `get_window_url_params` JS snippet (may contain a "model" key).
        request: gradio request object; used only to log the client IP.

    Returns:
        A 6-tuple of (conversation state, model dropdown update, chatbot
        update, textbox update, button-row update, parameter-accordion
        update) that un-hides the widgets once loading finishes.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    # Default: just reveal the dropdown; keep whatever model is selected.
    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        # Only honor the ?model= request if it names a model we actually serve.
        if model in models:
            dropdown_update = gr.Dropdown.update(
                value=model, visible=True)

    state = default_conversation.copy()
    # NOTE: the scraped diff showed both the old `gr.Dropdown.update(...)` line
    # and its replacement `dropdown_update` in this tuple; only the replacement
    # belongs here — the function is wired to exactly six outputs.
    return (state,
            dropdown_update,
            gr.Chatbot.update(visible=True),
            gr.Textbox.update(visible=True),
            gr.Row.update(visible=True),
            gr.Accordion.update(visible=True))


def load_demo_refresh_model_list(request: gr.Request):
logger.info(f"load demo: {request.client.host}")
logger.info(f"load_demo. ip: {request.client.host}")
models = get_model_list()
state = default_conversation.copy()
return (state, gr.Dropdown.update(
Expand All @@ -71,7 +89,6 @@ def load_demo_refresh_model_list(request: gr.Request):


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
logger.info(f"vote_type: {vote_type}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(time.time(), 4),
Expand All @@ -84,32 +101,38 @@ def vote_last_response(state, vote_type, model_selector, request: gr.Request):


def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote for the latest reply, then disable the three vote buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return tuple(disable_btn for _ in range(3))


def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote for the latest reply, then disable the three vote buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return tuple(disable_btn for _ in range(3))


def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag report for the latest reply, then disable the three vote buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return tuple(disable_btn for _ in range(3))


def regenerate(state, request: gr.Request):
    """Drop the last assistant reply so the next http_bot call regenerates it.

    Returns (state, chatbot messages, cleared textbox) plus five disabled
    button updates while regeneration is in flight.
    """
    # The scraped diff retained the superseded `def regenerate(state):` header
    # above the new one; only the signature taking `request` (for IP logging)
    # is current.
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None  # erase the reply; http_bot will refill it
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5


def clear_history(request: gr.Request):
    """Reset the chat to a fresh default conversation.

    Returns (new state, empty chatbot, cleared textbox) plus five disabled
    button updates.
    """
    # The scraped diff retained the superseded `def clear_history():` header
    # above the new one; only the signature taking `request` (for IP logging)
    # is current.
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5


def add_text(state, text, request: gr.Request):
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
if len(text) <= 0:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
Expand Down Expand Up @@ -139,6 +162,7 @@ def post_process_code(code):


def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
logger.info(f"http_bot. ip: {request.client.host}")
start_tstamp = time.time()
model_name = model_selector

Expand Down Expand Up @@ -245,6 +269,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req


# Static markdown rendered at the bottom of the demo page (via gr.Markdown in
# build_demo): links to sample benchmark outputs and states the license terms.
learn_more_markdown = ("""
### Evaluation Samples
The online demo has limited capacity. If you are waiting in the queue, feel free to explore some sample outputs of the models on our benchmark questions by visiting a static website [here](https://vicuna.lmsys.org/eval).
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
Expand All @@ -262,8 +289,6 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req


def build_demo():
models = get_model_list()

with gr.Blocks(title="FastChat", theme=gr.themes.Base(), css=css) as demo:
state = gr.State()

Expand Down Expand Up @@ -294,6 +319,7 @@ def build_demo():
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

gr.Markdown(learn_more_markdown)
url_params = gr.JSON(visible=False)

# Register listeners
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
Expand All @@ -314,8 +340,9 @@ def build_demo():
[state, chatbot] + btn_list)

if args.model_list_mode == "once":
demo.load(load_demo, None, [state, model_selector,
chatbot, textbox, button_row, parameter_row])
demo.load(load_demo, [url_params], [state, model_selector,
chatbot, textbox, button_row, parameter_row],
_js=get_window_url_params)
elif args.model_list_mode == "reload":
demo.load(load_demo_refresh_model_list, None, [state, model_selector,
chatbot, textbox, button_row, parameter_row])
Expand All @@ -337,6 +364,8 @@ def build_demo():
parser.add_argument("--moderate", action="store_true")
args = parser.parse_args()

models = get_model_list()

demo = build_demo()
demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
api_open=False).launch(
Expand Down

0 comments on commit df8132d

Please sign in to comment.