feat: support human template upload for Gradio (KwaiVGI#302)
* support uploading a motion template for fast inference

* Update gradio_pipeline.py

---------

Co-authored-by: Mystery099 <[email protected]>
iflamed and Mystery099 authored Aug 9, 2024
1 parent b95d7b6 commit f24b6ff
Showing 2 changed files with 50 additions and 19 deletions.
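Note: this change lets the Gradio demo accept a pre-computed .pkl motion template as the driving input, instead of only a raw driving video, which the commit message describes as enabling faster inference. As a quick orientation, the sketch below shows one way to peek inside such a template. It assumes only that the file is a pickled Python object; the actual schema is defined by the LivePortrait pipeline and is not part of this diff, and the path is a placeholder.

import pickle

# Placeholder path: the demo ships example templates such as d1.pkl under its
# example driving-input directory.
template_path = "d1.pkl"

# Only unpickle templates from a source you trust; pickle can execute code.
with open(template_path, "rb") as f:
    template = pickle.load(f)

# Inspect the structure without assuming a schema.
print(type(template))
if isinstance(template, dict):
    print(sorted(template.keys()))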
52 changes: 37 additions & 15 deletions app.py
@@ -180,20 +180,40 @@ def reset_sliders(*args, **kwargs):
                     vy_ratio = gr.Number(value=-0.125, label="source crop y", minimum=-0.5, maximum=0.5, step=0.01)

         with gr.Column():
-            with gr.Accordion(open=True, label="Driving Video"):
-                driving_video_input = gr.Video()
-                gr.Examples(
-                    examples=[
-                        [osp.join(example_video_dir, "d0.mp4")],
-                        [osp.join(example_video_dir, "d18.mp4")],
-                        [osp.join(example_video_dir, "d19.mp4")],
-                        [osp.join(example_video_dir, "d14.mp4")],
-                        [osp.join(example_video_dir, "d6.mp4")],
-                        [osp.join(example_video_dir, "d20.mp4")],
-                    ],
-                    inputs=[driving_video_input],
-                    cache_examples=False,
-                )
+            with gr.Tabs():
+                with gr.TabItem("🎞️ Driving Video") as v_tab_video:
+                    with gr.Accordion(open=True, label="Driving Video"):
+                        driving_video_input = gr.Video()
+                        gr.Examples(
+                            examples=[
+                                [osp.join(example_video_dir, "d0.mp4")],
+                                [osp.join(example_video_dir, "d18.mp4")],
+                                [osp.join(example_video_dir, "d19.mp4")],
+                                [osp.join(example_video_dir, "d14.mp4")],
+                                [osp.join(example_video_dir, "d6.mp4")],
+                                [osp.join(example_video_dir, "d20.mp4")],
+                            ],
+                            inputs=[driving_video_input],
+                            cache_examples=False,
+                        )
+                with gr.TabItem("📁 Driving Pickle") as v_tab_pickle:
+                    with gr.Accordion(open=True, label="Driving Pickle"):
+                        driving_video_pickle_input = gr.File(type="filepath", file_types=[".pkl"])
+                        gr.Examples(
+                            examples=[
+                                [osp.join(example_video_dir, "d1.pkl")],
+                                [osp.join(example_video_dir, "d2.pkl")],
+                                [osp.join(example_video_dir, "d5.pkl")],
+                                [osp.join(example_video_dir, "d7.pkl")],
+                                [osp.join(example_video_dir, "d8.pkl")],
+                            ],
+                            inputs=[driving_video_pickle_input],
+                            cache_examples=False,
+                        )
+
+            v_tab_selection = gr.Textbox(visible=False)
+            v_tab_pickle.select(lambda: "Pickle", None, v_tab_selection)
+            v_tab_video.select(lambda: "Video", None, v_tab_selection)
             # with gr.Accordion(open=False, label="Animation Instructions"):
             # gr.Markdown(load_description("assets/gradio/gradio_description_animation.md"))
             with gr.Accordion(open=True, label="Cropping Options for Driving Video"):
@@ -225,7 +245,7 @@ def reset_sliders(*args, **kwargs):
             with gr.Accordion(open=True, label="The animated video"):
                 output_video_concat_i2v.render()
     with gr.Row():
-        process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")
+        process_button_reset = gr.ClearButton([source_image_input, source_video_input, driving_video_pickle_input, driving_video_input, output_video_i2v, output_video_concat_i2v], value="🧹 Clear")

     with gr.Row():
         # Examples
@@ -393,6 +413,7 @@ def reset_sliders(*args, **kwargs):
         inputs=[
             source_image_input,
             source_video_input,
+            driving_video_pickle_input,
             driving_video_input,
             flag_relative_input,
             flag_do_crop_input,
@@ -410,6 +431,7 @@ def reset_sliders(*args, **kwargs):
             vy_ratio_crop_driving_video,
             driving_smooth_observation_variance,
             tab_selection,
+            v_tab_selection,
         ],
         outputs=[output_video_i2v, output_video_concat_i2v],
         show_progress=True
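The app.py changes above follow a common Gradio pattern: two gr.TabItem blocks offer alternative driving inputs, and a hidden gr.Textbox records which tab was selected last so the click handler can branch on it. A minimal, self-contained sketch of that pattern is below (illustrative names and a dummy handler, not the demo's real components):

import gradio as gr

def run(driving_video, driving_pickle, which_tab):
    # Pick the driving input from the last-selected tab, mirroring the
    # v_tab_selection logic this commit adds to app.py.
    driving = driving_pickle if which_tab == "Pickle" else driving_video
    return f"would run inference with driving input: {driving}"

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Driving Video") as tab_video:
            driving_video = gr.Video()
        with gr.TabItem("Driving Pickle") as tab_pickle:
            driving_pickle = gr.File(type="filepath", file_types=[".pkl"])

    which_tab = gr.Textbox(visible=False)  # hidden holder for the tab state
    tab_video.select(lambda: "Video", None, which_tab)
    tab_pickle.select(lambda: "Pickle", None, which_tab)

    result = gr.Textbox(label="result")
    gr.Button("Run").click(run, [driving_video, driving_pickle, which_tab], result)

if __name__ == "__main__":
    demo.launch()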
17 changes: 13 additions & 4 deletions src/gradio_pipeline.py
@@ -146,6 +146,7 @@ def execute_video(
         self,
         input_source_image_path=None,
         input_source_video_path=None,
+        input_driving_video_pickle_path=None,
         input_driving_video_path=None,
         flag_relative_input=True,
         flag_do_crop_input=True,
@@ -163,8 +164,9 @@ def execute_video(
         vy_ratio_crop_driving_video=-0.1,
         driving_smooth_observation_variance=3e-7,
         tab_selection=None,
+        v_tab_selection=None
     ):
-        """ for video-driven potrait animation or video editing
+        """ for video-driven portrait animation or video editing
         """
         if tab_selection == 'Image':
             input_source_path = input_source_image_path
@@ -173,15 +175,22 @@ def execute_video(
         else:
             input_source_path = input_source_image_path

-        if input_source_path is not None and input_driving_video_path is not None:
-            if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False:
+        if v_tab_selection == 'Video':
+            input_driving_path = input_driving_video_path
+        elif v_tab_selection == 'Pickle':
+            input_driving_path = input_driving_video_pickle_path
+        else:
+            input_driving_path = input_driving_video_path
+
+        if input_source_path is not None and input_driving_path is not None:
+            if osp.exists(input_driving_path) and v_tab_selection == 'Video' and is_square_video(input_driving_path) is False:
                 flag_crop_driving_video_input = True
                 log("The driving video is not square, it will be cropped to square automatically.")
                 gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)

             args_user = {
                 'source': input_source_path,
-                'driving': input_driving_video_path,
+                'driving': input_driving_path,
                 'flag_relative_motion': flag_relative_input,
                 'flag_do_crop': flag_do_crop_input,
                 'flag_pasteback': flag_remap_input,
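Taken together with the app.py wiring, the new v_tab_selection argument means is_square_video() is only consulted when the driving input really is a video; an uploaded .pkl template bypasses that check and is passed straight through as args_user['driving']. A hypothetical direct call with a template could look like the sketch below; pipeline stands for an already-initialized GradioPipeline, the paths are placeholders, and the keyword names come from the signature shown in the diff above.

# Hypothetical usage sketch, not project documentation.
video_path, video_concat_path = pipeline.execute_video(
    input_source_image_path="source.jpg",             # placeholder source image
    input_driving_video_pickle_path="template.pkl",   # placeholder motion template
    tab_selection="Image",     # source comes from the image tab
    v_tab_selection="Pickle",  # driving input comes from the pickle tab
)
# The two return values are assumed to match the Gradio outputs wired in app.py
# (output_video_i2v and output_video_concat_i2v).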
