chore: slightly refine the codebase
cleardusk committed Jul 5, 2024
1 parent 669487a commit d09527c
Showing 10 changed files with 35 additions and 39 deletions.
8 changes: 4 additions & 4 deletions app.py
@@ -44,8 +44,8 @@ def partial_fields(target_class, kwargs)
 #################### interface logic ####################

 # Define components first
-eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
-lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
+eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eyes-open ratio")
+lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-open ratio")
 retargeting_input_image = gr.Image(type="numpy")
 output_image = gr.Image(type="numpy")
 output_image_paste_back = gr.Image(type="numpy")
@@ -56,15 +56,15 @@ def partial_fields(target_class, kwargs):
     gr.HTML(load_description(title_md))
     gr.Markdown(load_description("assets/gradio_description_upload.md"))
     with gr.Row():
-        with gr.Accordion(open=True, label="Reference Portrait"):
+        with gr.Accordion(open=True, label="Source Portrait"):
             image_input = gr.Image(type="filepath")
         with gr.Accordion(open=True, label="Driving Video"):
             video_input = gr.Video()
     gr.Markdown(load_description("assets/gradio_description_animation.md"))
     with gr.Row():
         with gr.Accordion(open=True, label="Animation Options"):
             with gr.Row():
-                flag_relative_input = gr.Checkbox(value=True, label="relative pose")
+                flag_relative_input = gr.Checkbox(value=True, label="relative motion")
                 flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
                 flag_remap_input = gr.Checkbox(value=True, label="paste-back")
             with gr.Row():
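Both hunks above sit inside `partial_fields(target_class, kwargs)`, whose body this commit does not show. As orientation, a minimal sketch of what such a helper plausibly does (filter a kwargs dict down to a dataclass's declared fields) could look like this; it is an assumption, not the repository's actual implementation:

```python
from dataclasses import fields

def partial_fields(target_class, kwargs):
    # Keep only the keys that match fields declared on the target dataclass,
    # then build an instance; unrelated keys are silently dropped.
    valid = {f.name for f in fields(target_class)}
    return target_class(**{k: v for k, v in kwargs.items() if k in valid})
```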
4 changes: 2 additions & 2 deletions assets/gradio_description_animation.md
@@ -1,7 +1,7 @@
-<span style="font-size: 1.2em;">🔥 To animate the reference portrait with the driving video, please follow these steps:</span>
+<span style="font-size: 1.2em;">🔥 To animate the source portrait with the driving video, please follow these steps:</span>
 <div style="font-size: 1.2em; margin-left: 20px;">
 1. Specify the options in the <strong>Animation Options</strong> section. We recommend checking the <strong>do crop</strong> option when facial areas occupy a relatively small portion of your image.
 </div>
 <div style="font-size: 1.2em; margin-left: 20px;">
 2. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments.
 </div>
2 changes: 1 addition & 1 deletion assets/gradio_description_retargeting.md
@@ -1 +1 @@
-<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
+<span style="font-size: 1.2em;">🔥 To change the target eyes-open and lip-open ratio of the source portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
6 changes: 2 additions & 4 deletions assets/gradio_description_upload.md
@@ -1,4 +1,2 @@
-## 🤗 This is the official gradio demo for **Live Portrait**.
-### Guidance for the gradio page:
-<div style="font-size: 1.2em;">Please upload or use the webcam to get a reference portrait to the <strong>Reference Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
-
+## 🤗 This is the official gradio demo for **LivePortrait**.
+<div style="font-size: 1.2em;">Please upload or use the webcam to get a source portrait to the <strong>Source Portrait</strong> field and a driving video to the <strong>Driving Video</strong> field.</div>
6 changes: 3 additions & 3 deletions src/config/argument_config.py
@@ -14,7 +14,7 @@
 @dataclass(repr=False) # use repr from PrintableConfig
 class ArgumentConfig(PrintableConfig):
     ########## input arguments ##########
-    source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the reference portrait
+    source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
     driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
     output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
     #####################################
@@ -25,9 +25,9 @@ class ArgumentConfig(PrintableConfig):
     flag_eye_retargeting: bool = False
     flag_lip_retargeting: bool = False
     flag_stitching: bool = True # we recommend setting it to True!
-    flag_relative: bool = True # whether to use relative pose
+    flag_relative: bool = True # whether to use relative motion
     flag_pasteback: bool = True # whether to paste-back/stitch the animated face cropping from the face-cropping space to the original image space
-    flag_do_crop: bool = True # whether to crop the reference portrait to the face-cropping space
+    flag_do_crop: bool = True # whether to crop the source portrait to the face-cropping space
     flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
     #########################################

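The `tyro.conf.arg(aliases=...)` annotations above imply the entry script parses this dataclass with `tyro.cli`. A usage sketch under that assumption (the script name `inference.py` is hypothetical):

```python
import tyro
from src.config.argument_config import ArgumentConfig

# Parse command-line flags into the dataclass; the aliases let users write e.g.:
#   python inference.py -s assets/examples/source/s6.jpg -d assets/examples/driving/d0.mp4 -o animations/
args = tyro.cli(ArgumentConfig)
print(args.source_image, args.driving_info, args.output_dir)
```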
4 changes: 2 additions & 2 deletions src/config/inference_config.py
@@ -28,7 +28,7 @@ class InferenceConfig(PrintableConfig):
     flag_lip_retargeting: bool = False
     flag_stitching: bool = True # we recommend setting it to True!

-    flag_relative: bool = True # whether to use relative pose
+    flag_relative: bool = True # whether to use relative motion
     anchor_frame: int = 0 # set this value if find_best_frame is True

     input_shape: Tuple[int, int] = (256, 256) # input shape
@@ -45,5 +45,5 @@
     ref_shape_n: int = 2

     device_id: int = 0
-    flag_do_crop: bool = False # whether to crop the reference portrait to the face-cropping space
+    flag_do_crop: bool = False # whether to crop the source portrait to the face-cropping space
     flag_do_rot: bool = True # whether to conduct the rotation when flag_do_crop is True
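`InferenceConfig` repeats several `ArgumentConfig` flags (`flag_relative`, `flag_do_crop`, `flag_stitching`, ...), so a natural bridge between the two is the `partial_fields` helper sketched after the app.py diff above; again an assumption about code this diff does not show:

```python
from src.config.argument_config import ArgumentConfig
from src.config.inference_config import InferenceConfig

args = ArgumentConfig()
# Overlapping fields carry over from the CLI arguments; everything else
# keeps its InferenceConfig default (partial_fields as sketched earlier).
inference_cfg = partial_fields(InferenceConfig, args.__dict__)
```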
16 changes: 8 additions & 8 deletions src/gradio_pipeline.py
@@ -35,7 +35,7 @@ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
         self.mask_ori = None
         self.img_rgb = None
         self.crop_M_c2o = None
-
+
     def execute_video(
         self,
@@ -62,9 +62,9 @@ def execute_video(
             # video driven animation
             video_path, video_path_concat = self.execute(self.args)
             gr.Info("Run successfully!", duration=2)
-            return video_path, video_path_concat,
+            return video_path, video_path_concat,
         else:
-            raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)
+            raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)

     def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         """ for single image retargeting
@@ -74,12 +74,12 @@ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         elif self.f_s_user is None:
             if self.start_prepare:
                 raise gr.Error(
-                    "The reference portrait is under processing 💥! Please wait for a second.",
+                    "The source portrait is under processing 💥! Please wait for a second.",
                     duration=5
                 )
             else:
                 raise gr.Error(
-                    "The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
+                    "The source portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
                     duration=5
                 )
         else:
@@ -98,7 +98,7 @@ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
         out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
         gr.Info("Run successfully!", duration=2)
         return out, out_to_ori_blend
-
+
     def prepare_retargeting(self, input_image_path, flag_do_crop = True):
         """ for single image retargeting
@@ -107,7 +107,7 @@ def prepare_retargeting(self, input_image_path, flag_do_crop = True):
         gr.Info("Upload successfully!", duration=2)
         self.start_prepare = True
         inference_cfg = self.live_portrait_wrapper.cfg
-        ######## process reference portrait ########
+        ######## process source portrait ########
         img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
         log(f"Load source image from {input_image_path}.")
         crop_info = self.cropper.crop_single_image(img_rgb)
@@ -125,7 +125,7 @@ def prepare_retargeting(self, input_image_path, flag_do_crop = True):
         self.x_s_info_user = x_s_info
         self.source_lmk_user = crop_info['lmk_crop']
         self.img_rgb = img_rgb
-        self.crop_M_c2o = crop_info['M_c2o']
+        self.crop_M_c2o = crop_info['M_c2o']
         self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
         # update slider
         eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
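For orientation, `execute_image` returns the pair `(out, out_to_ori_blend)`, which matches the two output images defined in app.py. A sketch of the likely wiring, assuming a `GradioPipeline` instance named `gradio_pipeline` (the button variable is hypothetical):

```python
import gradio as gr

retargeting_button = gr.Button("🚗 Retargeting")
# Slider values feed execute_image; its two return values fill the two
# result images (retargeted crop and paste-back blend).
retargeting_button.click(
    fn=gradio_pipeline.execute_image,
    inputs=[eye_retargeting_slider, lip_retargeting_slider],
    outputs=[output_image, output_image_paste_back],
)
```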
2 changes: 1 addition & 1 deletion src/live_portrait_pipeline.py
@@ -40,7 +40,7 @@ def __init__(self, inference_cfg: InferenceConfig, crop_cfg: CropConfig):

     def execute(self, args: ArgumentConfig):
         inference_cfg = self.live_portrait_wrapper.cfg # for convenience
-        ######## process reference portrait ########
+        ######## process source portrait ########
         img_rgb = load_image_rgb(args.source_image)
         img_rgb = resize_to_limit(img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n)
         log(f"Load source image from {args.source_image}")
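`resize_to_limit` is not part of this diff; judging from its call site and from `load_img_online(..., max_dim=1280, n=16)` in gradio_pipeline.py, it plausibly caps the longer image side at `ref_max_shape` and snaps both sides to a multiple of `ref_shape_n`. A sketch under that assumption:

```python
import cv2
import numpy as np

def resize_to_limit(img: np.ndarray, max_dim: int, n: int) -> np.ndarray:
    # Assumed behavior: shrink so the longer side is at most max_dim,
    # then floor each side to a multiple of n.
    h, w = img.shape[:2]
    if max(h, w) > max_dim:
        scale = max_dim / max(h, w)
        h, w = int(h * scale), int(w * scale)
    h, w = max(h - h % n, n), max(w - w % n, n)
    return cv2.resize(img, (w, h))
```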
10 changes: 5 additions & 5 deletions src/live_portrait_wrapper.py
@@ -10,12 +10,12 @@
 import torch
 import yaml

-from src.utils.timer import Timer
-from src.utils.helper import load_model, concat_feat
-from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
+from .utils.timer import Timer
+from .utils.helper import load_model, concat_feat
+from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
 from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
-from src.config.inference_config import InferenceConfig
-from src.utils.rprint import rlog as log
+from .config.inference_config import InferenceConfig
+from .utils.rprint import rlog as log


 class LivePortraitWrapper(object):
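The switch from `src.`-absolute to relative imports is more than cosmetic: absolute imports hard-code the top-level package name. A comment-only illustration of the assumed layout:

```python
# Package layout implied by the relative imports:
#
#   src/
#     live_portrait_wrapper.py   # from .utils.timer import Timer
#     utils/timer.py
#     config/inference_config.py
#
# `from src.utils.timer import Timer` breaks as soon as the package is
# vendored, renamed, or installed under anything other than `src`
# (e.g. as `liveportrait`). Relative imports resolve against the
# enclosing package, so the same files work under any top-level name.
```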
16 changes: 7 additions & 9 deletions src/utils/helper.py
@@ -6,17 +6,14 @@

 import os
 import os.path as osp
-import cv2
-import torch
-from rich.console import Console
 from collections import OrderedDict

-from src.modules.spade_generator import SPADEDecoder
-from src.modules.warping_network import WarpingNetwork
-from src.modules.motion_extractor import MotionExtractor
-from src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
-from src.modules.stitching_retargeting_network import StitchingRetargetingNetwork
+from ..modules.spade_generator import SPADEDecoder
+from ..modules.warping_network import WarpingNetwork
+from ..modules.motion_extractor import MotionExtractor
+from ..modules.appearance_feature_extractor import AppearanceFeatureExtractor
+from ..modules.stitching_retargeting_network import StitchingRetargetingNetwork
 from .rprint import rlog as log


 def suffix(filename):
@@ -45,6 +42,7 @@ def is_video(file_path):
         return True
     return False

+
 def is_template(file_path):
     if file_path.endswith(".pkl"):
         return True
@@ -149,8 +147,8 @@ def calculate_transformation(config, s_kp_info, t_0_kp_info, t_i_kp_info, R_s, R
     new_scale = s_kp_info['scale'] * (t_i_kp_info['scale'] / t_0_kp_info['scale'])
     return new_rotation, new_expression, new_translation, new_scale

+
 def load_description(fp):
     with open(fp, 'r', encoding='utf-8') as f:
         content = f.read()
     return content

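`load_description` is consumed in app.py (see the first file in this commit); its use is simply:

```python
import gradio as gr

with gr.Blocks() as demo:
    # Inject the markdown snippets edited elsewhere in this commit.
    gr.Markdown(load_description("assets/gradio_description_upload.md"))
    gr.Markdown(load_description("assets/gradio_description_animation.md"))
```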