From c92a25e5dd62f72df1993ffd76a192d8d9305a72 Mon Sep 17 00:00:00 2001 From: AricGamma Date: Tue, 25 Jun 2024 00:04:40 +0800 Subject: [PATCH] feat: merge config with cli args (#89) Co-authored-by: leeway.zlw --- configs/inference/default.yaml | 12 ++++++------ hallo/utils/config.py | 25 +++++++++++++++++++++++++ scripts/inference.py | 21 +++++++++++---------- 3 files changed, 42 insertions(+), 16 deletions(-) create mode 100644 hallo/utils/config.py diff --git a/configs/inference/default.yaml b/configs/inference/default.yaml index 5a7835ba..f755bde4 100644 --- a/configs/inference/default.yaml +++ b/configs/inference/default.yaml @@ -1,5 +1,5 @@ -source_image: ./default.png -driving_audio: default.wav +source_image: examples/reference_images/1.jpg +driving_audio: examples/driving_audios/1.wav weight_dtype: fp16 @@ -38,10 +38,10 @@ vae: save_path: ./.cache -face_expand_ratio: 1.1 -pose_weight: 1.1 -face_weight: 1.1 -lip_weight: 1.1 +face_expand_ratio: 1.2 +pose_weight: 1.0 +face_weight: 1.0 +lip_weight: 1.0 unet_additional_kwargs: use_inflated_groupnorm: true diff --git a/hallo/utils/config.py b/hallo/utils/config.py new file mode 100644 index 00000000..69854b61 --- /dev/null +++ b/hallo/utils/config.py @@ -0,0 +1,25 @@ +""" +This module provides utility functions for configuration manipulation. +""" + +from typing import Dict + + +def filter_non_none(dict_obj: Dict): + """ + Filters out key-value pairs from the given dictionary where the value is None. + + Args: + dict_obj (Dict): The dictionary to be filtered. + + Returns: + Dict: The dictionary with key-value pairs removed where the value was None. + + This function creates a new dictionary containing only the key-value pairs from + the original dictionary where the value is not None. It then clears the original + dictionary and updates it with the filtered key-value pairs. + """ + non_none_filter = { k: v for k, v in dict_obj.items() if v is not None } + dict_obj.clear() + dict_obj.update(non_none_filter) + return dict_obj diff --git a/scripts/inference.py b/scripts/inference.py index 8dfefbf8..5b780b1e 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -44,6 +44,7 @@ from hallo.models.image_proj import ImageProjModel from hallo.models.unet_2d_condition import UNet2DConditionModel from hallo.models.unet_3d import UNet3DConditionModel +from hallo.utils.config import filter_non_none from hallo.utils.util import tensor_to_video @@ -125,16 +126,16 @@ def inference_process(args: argparse.Namespace): modules and variables to prepare for the upcoming inference steps. """ # 1. init config + cli_args = filter_non_none(vars(args)) config = OmegaConf.load(args.config) - config = OmegaConf.merge(config, vars(args)) + config = OmegaConf.merge(config, cli_args) source_image_path = config.source_image driving_audio_path = config.driving_audio save_path = config.save_path if not os.path.exists(save_path): os.makedirs(save_path) motion_scale = [config.pose_weight, config.face_weight, config.lip_weight] - if args.checkpoint is not None: - config.audio_ckpt_dir = args.checkpoint + # 2. runtime variables device = torch.device( "cuda") if torch.cuda.is_available() else torch.device("cpu") @@ -353,21 +354,21 @@ def inference_process(args: argparse.Namespace): parser.add_argument( "-c", "--config", default="configs/inference/default.yaml") parser.add_argument("--source_image", type=str, required=False, - help="source image", default="test_data/source_images/6.jpg") + help="source image") parser.add_argument("--driving_audio", type=str, required=False, - help="driving audio", default="test_data/driving_audios/singing/sing_4.wav") + help="driving audio") parser.add_argument( "--output", type=str, help="output video file name", default=".cache/output.mp4") parser.add_argument( - "--pose_weight", type=float, help="weight of pose", default=1.0) + "--pose_weight", type=float, help="weight of pose", required=False) parser.add_argument( - "--face_weight", type=float, help="weight of face", default=1.0) + "--face_weight", type=float, help="weight of face", required=False) parser.add_argument( - "--lip_weight", type=float, help="weight of lip", default=1.0) + "--lip_weight", type=float, help="weight of lip", required=False) parser.add_argument( - "--face_expand_ratio", type=float, help="face region", default=1.2) + "--face_expand_ratio", type=float, help="face region", required=False) parser.add_argument( - "--checkpoint", type=str, help="which checkpoint", default=None) + "--audio_ckpt_dir", "--checkpoint", type=str, help="specific checkpoint dir", required=False) command_line_args = parser.parse_args()