Skip to content

Commit

Permalink
feat: training code of hallo (fudan-generative-vision#101)
Browse files Browse the repository at this point in the history
* add training code and corresponding config yaml files
* add some auxiliary funciton in utils
* fix a parameter bug in motion module
* fix mask size issue in stage2 dataset module "talk_video.py"
  • Loading branch information
xumingw authored Jun 26, 2024
1 parent cfd1815 commit 0152cd9
Show file tree
Hide file tree
Showing 8 changed files with 2,114 additions and 4 deletions.
63 changes: 63 additions & 0 deletions configs/train/stage1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
data:
train_bs: 8
train_width: 512
train_height: 512
meta_paths:
- "./data/HDTF_meta.json"
# Margin of frame indexes between ref and tgt images
sample_margin: 30

solver:
gradient_accumulation_steps: 1
mixed_precision: "no"
enable_xformers_memory_efficient_attention: True
gradient_checkpointing: False
max_train_steps: 30000
max_grad_norm: 1.0
# lr
learning_rate: 1.0e-5
scale_lr: False
lr_warmup_steps: 1
lr_scheduler: "constant"

# optimizer
use_8bit_adam: False
adam_beta1: 0.9
adam_beta2: 0.999
adam_weight_decay: 1.0e-2
adam_epsilon: 1.0e-8

val:
validation_steps: 500

noise_scheduler_kwargs:
num_train_timesteps: 1000
beta_start: 0.00085
beta_end: 0.012
beta_schedule: "scaled_linear"
steps_offset: 1
clip_sample: false

base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
face_analysis_model_path: "./pretrained_models/face_analysis"

weight_dtype: "fp16" # [fp16, fp32]
uncond_ratio: 0.1
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
face_locator_pretrained: False

seed: 42
resume_from_checkpoint: "latest"
checkpointing_steps: 500
exp_name: "stage1"
output_dir: "./exp_output"

ref_image_paths:
- "examples/reference_images/1.jpg"

mask_image_paths:
- "examples/masks/1.png"

119 changes: 119 additions & 0 deletions configs/train/stage2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
data:
train_bs: 4
val_bs: 1
train_width: 512
train_height: 512
fps: 25
sample_rate: 16000
n_motion_frames: 2
n_sample_frames: 14
audio_margin: 2
train_meta_paths:
- "./data/hdtf_split_stage2.json"

wav2vec_config:
audio_type: "vocals" # audio vocals
model_scale: "base" # base large
features: "all" # last avg all
model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h
audio_separator:
model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx
face_expand_ratio: 1.2

solver:
gradient_accumulation_steps: 1
mixed_precision: "no"
enable_xformers_memory_efficient_attention: True
gradient_checkpointing: True
max_train_steps: 30000
max_grad_norm: 1.0
# lr
learning_rate: 1e-5
scale_lr: False
lr_warmup_steps: 1
lr_scheduler: "constant"

# optimizer
use_8bit_adam: True
adam_beta1: 0.9
adam_beta2: 0.999
adam_weight_decay: 1.0e-2
adam_epsilon: 1.0e-8

val:
validation_steps: 1000

noise_scheduler_kwargs:
num_train_timesteps: 1000
beta_start: 0.00085
beta_end: 0.012
beta_schedule: "linear"
steps_offset: 1
clip_sample: false

unet_additional_kwargs:
use_inflated_groupnorm: true
unet_use_cross_frame_attention: false
unet_use_temporal_attention: false
use_motion_module: true
use_audio_module: true
motion_module_resolutions:
- 1
- 2
- 4
- 8
motion_module_mid_block: true
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 32
temporal_attention_dim_div: 1
audio_attention_dim: 768
stack_enable_blocks_name:
- "up"
- "down"
- "mid"
stack_enable_blocks_depth: [0,1,2,3]

trainable_para:
- audio_modules
- motion_modules

base_model_path: "./pretrained_models/stable-diffusion-v1-5/"
vae_model_path: "./pretrained_models/sd-vae-ft-mse"
face_analysis_model_path: "./pretrained_models/face_analysis"
mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt"

weight_dtype: "fp16" # [fp16, fp32]
uncond_img_ratio: 0.05
uncond_audio_ratio: 0.05
uncond_ia_ratio: 0.05
start_ratio: 0.05
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
stage1_ckpt_dir: "./pretrained_models/hallo/stage1"

single_inference_times: 10
inference_steps: 40
cfg_scale: 3.5

seed: 42
resume_from_checkpoint: "latest"
checkpointing_steps: 500
exp_name: "stage2_test"
output_dir: "./exp_output"

ref_img_path:
- "examples/reference_images/1.jpg"

audio_path:
- "examples/driving_audios/1.wav"


Binary file added examples/masks/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 8 additions & 4 deletions hallo/datasets/talk_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,25 +145,29 @@ def __init__(
)
self.attn_transform_64 = transforms.Compose(
[
transforms.Resize((64,64)),
transforms.Resize(
(self.img_size[0] // 8, self.img_size[0] // 8)),
transforms.ToTensor(),
]
)
self.attn_transform_32 = transforms.Compose(
[
transforms.Resize((32, 32)),
transforms.Resize(
(self.img_size[0] // 16, self.img_size[0] // 16)),
transforms.ToTensor(),
]
)
self.attn_transform_16 = transforms.Compose(
[
transforms.Resize((16, 16)),
transforms.Resize(
(self.img_size[0] // 32, self.img_size[0] // 32)),
transforms.ToTensor(),
]
)
self.attn_transform_8 = transforms.Compose(
[
transforms.Resize((8, 8)),
transforms.Resize(
(self.img_size[0] // 64, self.img_size[0] // 64)),
transforms.ToTensor(),
]
)
Expand Down
1 change: 1 addition & 0 deletions hallo/models/motion_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ def extra_repr(self):
def set_use_memory_efficient_attention_xformers(
self,
use_memory_efficient_attention_xformers: bool,
attention_op = None,
):
"""
Sets the use of memory-efficient attention xformers for the VersatileAttention class.
Expand Down
148 changes: 148 additions & 0 deletions hallo/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import subprocess
import sys
from pathlib import Path
from typing import List

import av
import cv2
Expand Down Expand Up @@ -614,3 +615,150 @@ def get_face_region(image_path: str, detector):
except Exception as e:
print(f"Error processing image {image_path}: {e}")
return None, None


def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None:
"""
Save the model's state_dict to a checkpoint file.
If `total_limit` is provided, this function will remove the oldest checkpoints
until the total number of checkpoints is less than the specified limit.
Args:
model (nn.Module): The model whose state_dict is to be saved.
save_dir (str): The directory where the checkpoint will be saved.
prefix (str): The prefix for the checkpoint file name.
ckpt_num (int): The checkpoint number to be saved.
total_limit (int, optional): The maximum number of checkpoints to keep.
Defaults to None, in which case no checkpoints will be removed.
Raises:
FileNotFoundError: If the save directory does not exist.
ValueError: If the checkpoint number is negative.
OSError: If there is an error saving the checkpoint.
"""

if not osp.exists(save_dir):
raise FileNotFoundError(
f"The save directory {save_dir} does not exist.")

if ckpt_num < 0:
raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.")

save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

if total_limit > 0:
checkpoints = os.listdir(save_dir)
checkpoints = [d for d in checkpoints if d.startswith(prefix)]
checkpoints = sorted(
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
)

if len(checkpoints) >= total_limit:
num_to_remove = len(checkpoints) - total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
print(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
print(
f"Removing checkpoints: {', '.join(removing_checkpoints)}"
)

for removing_checkpoint in removing_checkpoints:
removing_checkpoint_path = osp.join(
save_dir, removing_checkpoint)
try:
os.remove(removing_checkpoint_path)
except OSError as e:
print(
f"Error removing checkpoint {removing_checkpoint_path}: {e}")

state_dict = model.state_dict()
try:
torch.save(state_dict, save_path)
print(f"Checkpoint saved at {save_path}")
except OSError as e:
raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e


def init_output_dir(dir_list: List[str]):
"""
Initialize the output directories.
This function creates the directories specified in the `dir_list`. If a directory already exists, it does nothing.
Args:
dir_list (List[str]): List of directory paths to create.
"""
for path in dir_list:
os.makedirs(path, exist_ok=True)


def load_checkpoint(cfg, save_dir, accelerator):
"""
Load the most recent checkpoint from the specified directory.
This function loads the latest checkpoint from the `save_dir` if the `resume_from_checkpoint` parameter is set to "latest".
If a specific checkpoint is provided in `resume_from_checkpoint`, it loads that checkpoint. If no checkpoint is found,
it starts training from scratch.
Args:
cfg: The configuration object containing training parameters.
save_dir (str): The directory where checkpoints are saved.
accelerator: The accelerator object for distributed training.
Returns:
int: The global step at which to resume training.
"""
if cfg.resume_from_checkpoint != "latest":
resume_dir = cfg.resume_from_checkpoint
else:
resume_dir = save_dir
# Get the most recent checkpoint
dirs = os.listdir(resume_dir)

dirs = [d for d in dirs if d.startswith("checkpoint")]
if len(dirs) > 0:
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1]
accelerator.load_state(os.path.join(resume_dir, path))
accelerator.print(f"Resuming from checkpoint {path}")
global_step = int(path.split("-")[1])
else:
accelerator.print(
f"Could not find checkpoint under {resume_dir}, start training from scratch")
global_step = 0

return global_step


def compute_snr(noise_scheduler, timesteps):
"""
Computes SNR as per
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/
# 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
timesteps
].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
device=timesteps.device
)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

# Compute SNR.
snr = (alpha / sigma) ** 2
return snr
Loading

0 comments on commit 0152cd9

Please sign in to comment.