Skip to content

Commit

Permalink
1148 loading status (gradio-app#1164)
Browse files Browse the repository at this point in the history
* update loader

* fix for non queued statuses

* remove logs

* Update demo/fake_gan/run.py
  • Loading branch information
pngwn authored May 5, 2022
1 parent a9610a4 commit 56222fb
Show file tree
Hide file tree
Showing 29 changed files with 392 additions and 169 deletions.
31 changes: 17 additions & 14 deletions demo/fake_gan/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def fake_gan(count, *args):
return images


cheetah = os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg")

demo = gr.Interface(
fn=fake_gan,
inputs=[
Expand All @@ -38,21 +40,22 @@ def fake_gan(count, *args):
title="FD-GAN",
description="This is a fake demo of a GAN. In reality, the images are randomly chosen from Unsplash.",
examples=[
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[os.path.join(os.path.dirname(__file__), "files/cheetah1.jpg"), 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
[2, cheetah, 12, 12, 4, 4],
],
enable_queue=True,
)

if __name__ == "__main__":
Expand Down
62 changes: 33 additions & 29 deletions ui/packages/app/src/Blocks.svelte
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
<script lang="ts">
import type { SvelteComponentTyped } from "svelte";
import { component_map } from "./components/directory";
import { loading_status } from "./stores";
import type { LoadingStatus } from "./stores";
import { _ } from "svelte-i18n";
import { setupi18n } from "./i18n";
import Render from "./Render.svelte";
Expand Down Expand Up @@ -83,7 +86,6 @@
}
};
}, {} as { [id: number]: Instance });
console.log(JSON.stringify(components));
function load_component<T extends keyof typeof component_map>(
name: T
Expand Down Expand Up @@ -145,7 +147,7 @@
let handled_dependencies: Array<number[]> = [];
let status_tracker_values: Record<number, string> = {};
async function handle_mount({ detail }) {
async function handle_mount() {
await tick();
dependencies.forEach(
(
Expand All @@ -166,18 +168,16 @@
outputs.every((v) => instance_map[v].instance) &&
inputs.every((v) => instance_map[v].instance)
) {
fn(
"predict",
fn({
action: "predict",
backend_fn,
frontend_fn,
{
payload: {
fn_index: i,
data: inputs.map((id) => instance_map[id].value)
},
outputs.map((id) => instance_map[id].value),
queue === null ? enable_queue : queue,
() => {}
)
queue: queue === null ? enable_queue : queue
})
.then((output) => {
output.data.forEach((value, i) => {
instance_map[outputs[i]].value = value;
Expand All @@ -193,40 +193,30 @@
target_instances.forEach(([id, { instance }]: [number, Instance]) => {
if (handled_dependencies[i]?.includes(id) || !instance) return;
instance?.$on(trigger, () => {
if (status === "pending") {
console.log(loading_status.get_status_for_fn(i));
if (loading_status.get_status_for_fn(i) === "pending") {
return;
}
outputs.forEach((_id) =>
set_prop(instance_map[_id], "loading_status", "pending")
);
fn(
"predict",
// page events
fn({
action: "predict",
backend_fn,
frontend_fn,
{
payload: {
fn_index: i,
data: inputs.map((id) => instance_map[id].value)
},
outputs.map((id) => instance_map[id].value),
queue === null ? enable_queue : queue,
() => {}
)
output_data: outputs.map((id) => instance_map[id].value),
queue: queue === null ? enable_queue : queue
})
.then((output) => {
output.data.forEach((value, i) => {
instance_map[outputs[i]].value = value;
set_prop(
instance_map[outputs[i]],
"loading_status",
"complete"
);
});
})
.catch((error) => {
outputs.forEach((_id) =>
set_prop(instance_map[_id], "loading_status", "error")
);
console.error(error);
});
});
Expand All @@ -243,6 +233,20 @@
return dep.filter((_id) => _id !== id);
});
}
$: set_status($loading_status);
dependencies.forEach((v, i) => {
loading_status.register(i, v.outputs);
});
function set_status(
statuses: Record<number, Omit<LoadingStatus, "outputs">>
) {
for (const id in statuses) {
set_prop(instance_map[id], "loading_status", statuses[id]);
}
}
</script>

<svelte:head>
Expand Down
1 change: 1 addition & 0 deletions ui/packages/app/src/Render.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
node.props.form_position = "single";
}
}
children =
children &&
children.filter((v) => instance_map[v.id].type !== "statustracker");
Expand Down
151 changes: 108 additions & 43 deletions ui/packages/app/src/api.ts
Original file line number Diff line number Diff line change
@@ -1,71 +1,136 @@
import { LoadingStatus, loading_status } from "./stores";

type StatusResponse =
| {
status: "COMPLETE";
data: { duration: number; average_duration: number; data: unknown };
}
| {
status: "QUEUED";
data: number;
}
| {
status: "PENDING";
data: null;
}
| {
status: "FAILED";
data: Record<string, unknown>;
};

interface Payload {
data: Record<string, unknown>;
fn_index: number;
}

function delay(n: number) {
return new Promise(function (resolve) {
setTimeout(resolve, n * 1000);
});
}

let postData = async (url: string, body: unknown) => {
const output = await fetch(url, {
async function post_data<
Return extends Record<string, unknown> = Record<string, unknown>
>(url: string, body: unknown): Promise<Return> {
const response = await fetch(url, {
method: "POST",
body: JSON.stringify(body),
headers: { "Content-Type": "application/json" }
});

if (response.status !== 200) {
throw new Error(response.statusText);
}

const output: Return = await response.json();

return output;
};
}

export const fn = async (
session_hash: string,
api_endpoint: string,
action: string,
backend_fn: boolean,
frontend_fn: Function | undefined,
data: Record<string, unknown>,
output_data: Array<any>,
queue: boolean,
queue_callback: (pos: number | null, is_initial?: boolean) => void
{
action,
payload,
queue,
backend_fn,
frontend_fn,
output_data
}: {
action: string;
payload: Payload;
queue: boolean;
backend_fn: boolean;
frontend_fn: Function | undefined;
output_data: Array<any>;
}
) => {
const fn_index = payload.fn_index;

if (frontend_fn !== undefined) {
data.data = frontend_fn(data.data.concat(output_data));
payload.data = frontend_fn(payload.data.concat(output_data));
}
if (backend_fn == false) {
return data;
return payload;
}
data["session_hash"] = session_hash;

if (queue && ["predict", "interpret"].includes(action)) {
data["action"] = action;
const output = await postData(api_endpoint + "queue/push/", data);
const output_json = await output.json();
let [hash, queue_position] = [
output_json["hash"],
output_json["queue_position"]
];
queue_callback(queue_position, /*is_initial=*/ true);
let status = "UNKNOWN";
while (status != "COMPLETE" && status != "FAILED") {
if (status != "UNKNOWN") {
await delay(1);
}
const status_response = await postData(api_endpoint + "queue/status/", {
hash: hash
});
var status_obj = await status_response.json();
status = status_obj["status"];
loading_status.update(fn_index as number, "pending", null, null);

const { hash, queue_position } = await post_data<{
hash: string;
queue_position: number;
}>(api_endpoint + "queue/push/", { ...payload, action, session_hash });

loading_status.update(fn_index, "pending", queue_position, null);

for (;;) {
await delay(1);

const { status, data } = await post_data<StatusResponse>(
api_endpoint + "queue/status/",
{
hash: hash
}
);

if (status === "QUEUED") {
queue_callback(status_obj["data"]);
loading_status.update(fn_index, "pending", data, null);
} else if (status === "PENDING") {
queue_callback(null);
loading_status.update(fn_index, "pending", 0, null);
} else if (status === "FAILED") {
loading_status.update(fn_index, "error", null, null);

throw new Error(status);
} else {
loading_status.update(
fn_index,
"complete",
null,
data.average_duration
);

return data;
}
}
if (status == "FAILED") {
throw new Error(status);
} else {
return status_obj["data"];
}
} else {
const output = await postData(api_endpoint + action + "/", data);
if (output.status !== 200) {
throw new Error(output.statusText);
}
return await output.json();
loading_status.update(fn_index as number, "pending", null, null);

const output = await post_data(api_endpoint + action + "/", {
...payload,
session_hash
});

console.log();

loading_status.update(
fn_index,
"complete",
null,
output.average_duration as number
);

return await output;
}
};
6 changes: 4 additions & 2 deletions ui/packages/app/src/components/Audio/Audio.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import { Block } from "@gradio/atoms";
import StatusTracker from "../StatusTracker/StatusTracker.svelte";
import type { LoadingStatus } from "../StatusTracker/types";
import { _ } from "svelte-i18n";
export let mode: "static" | "dynamic";
Expand All @@ -18,7 +20,7 @@
export let root: string;
export let show_label: boolean;
export let loading_status: "complete" | "pending" | "error";
export let loading_status: LoadingStatus;
if (default_value) value = default_value;
Expand All @@ -35,7 +37,7 @@
color={dragging ? "green" : "grey"}
padding={false}
>
<StatusTracker tracked_status={loading_status} />
<StatusTracker {...loading_status} />

{#if mode === "dynamic"}
<Audio
Expand Down
Loading

0 comments on commit 56222fb

Please sign in to comment.