Flux.1 performance optimizations on H100 (xdit-project#365)
xibosun authored Nov 28, 2024
1 parent a7bd749 commit c6c0f8a
Showing 4 changed files with 366 additions and 2 deletions.
168 changes: 168 additions & 0 deletions examples/flux_usp_example.py
@@ -0,0 +1,168 @@
import functools
from typing import List, Optional, Tuple, Union

import logging
import time
import torch
import torch.distributed
from diffusers import DiffusionPipeline, FluxPipeline

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_data_parallel_world_size,
    get_data_parallel_rank,
    get_runtime_state,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_cfg_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_sp_group,
    is_dp_last_group,
    initialize_runtime_state,
    get_pipeline_parallel_world_size,
)

from xfuser.model_executor.layers.attention_processor_usp import xFuserFluxAttnProcessor2_0USP

def parallelize_transformer(pipe: DiffusionPipeline):
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        *args,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        controlnet_block_samples: Optional[List[torch.Tensor]] = None,
        controlnet_single_block_samples: Optional[List[torch.Tensor]] = None,
        **kwargs,
    ):
        if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
            timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
        img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
        txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]

        for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
            block.attn.processor = xFuserFluxAttnProcessor2_0USP()

        output = original_forward(
            hidden_states,
            encoder_hidden_states,
            *args,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            **kwargs,
        )

        return_dict = not isinstance(output, tuple)
        sample = output[0]
        sample = get_sp_group().all_gather(sample, dim=-2)
        sample = get_cfg_group().all_gather(sample, dim=0)
        if return_dict:
            return output.__class__(sample, *output[1:])
        return (sample, *output[1:])

    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward


def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16
    local_rank = get_world_group().local_rank

    pipe = FluxPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        torch_dtype=torch.bfloat16,
    )

    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    else:
        pipe = pipe.to(f"cuda:{local_rank}")

    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    initialize_runtime_state(pipe, engine_config)
    get_runtime_state().set_input_parameters(
        height=input_config.height,
        width=input_config.width,
        batch_size=1,
        num_inference_steps=input_config.num_inference_steps,
        max_condition_sequence_length=512,
        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
    )

    parallelize_transformer(pipe)

    if engine_config.runtime_config.use_torch_compile:
        torch._inductor.config.reorder_for_compute_comm_overlap = True
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

    # warmup
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=1,
        output_type=input_config.output_type,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    ).images

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if input_config.output_type == "pil":
        dp_group_index = get_data_parallel_rank()
        num_dp_groups = get_data_parallel_world_size()
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if is_dp_last_group():
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
                image.save(f"./results/{image_name}")
                print(f"image {i} saved to ./results/{image_name}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
        )
    get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
    main()
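
A note on running this example: xDiT examples of this kind are launched with torchrun, one process per GPU. A plausible invocation (the flag names --model, --height, --width, --prompt, --num_inference_steps, --ulysses_degree and --ring_degree are assumed from xFuserArgs and may differ from the repository's README) is: torchrun --nproc_per_node=2 examples/flux_usp_example.py --model <path-to-FLUX.1> --height 1024 --width 1024 --prompt "a photo of a cat" --num_inference_steps 28 --ulysses_degree 2 --ring_degree 1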
7 changes: 5 additions & 2 deletions xfuser/core/distributed/group_coordinator.py
@@ -223,15 +223,18 @@ def all_gather(
             # Convert negative dim to positive.
             dim += input_.dim()
         # Allocate output tensor.
-        input_size = input_.size()
+        input_size = list(input_.size())
+        input_size[0] *= world_size
         output_tensor = torch.empty(
-            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
+            input_size, dtype=input_.dtype, device=input_.device
         )
         # All-gather.
         torch.distributed.all_gather_into_tensor(
             output_tensor, input_, group=self.device_group
         )
         if dim != 0:
+            input_size[0] //= world_size
+            output_tensor = output_tensor.reshape([world_size, ] + input_size)
             output_tensor = output_tensor.movedim(0, dim)

         if separate_tensors:
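
To make the layout change in the patched all_gather above concrete, here is a minimal single-process sketch (an illustration only: the per-rank tensors are simulated with a plain Python list instead of a process group). It shows that the flat buffer written by all_gather_into_tensor is a dim-0 concatenation of the rank chunks, and that for dim != 0 the reshape/movedim pair recovers a leading world_size axis and moves it to the requested position:

import torch

world_size, dim = 4, 2
chunks = [torch.randn(2, 3, 5) for _ in range(world_size)]    # stand-ins for per-rank inputs
input_size = list(chunks[0].size())

# all_gather_into_tensor writes the rank chunks back-to-back along dim 0
flat = torch.cat(chunks, dim=0)                               # [world_size * 2, 3, 5]

# dim == 0: the flat buffer is already the gathered result, no extra copy needed.
# dim != 0: recover a leading world_size axis, then move it to `dim`.
grouped = flat.reshape([world_size] + input_size).movedim(0, dim)
assert grouped.shape == (2, 3, world_size, 5)
assert torch.equal(grouped[:, :, 1], chunks[1])               # rank 1's chunk sits at index 1 of axis `dim`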
95 changes: 95 additions & 0 deletions xfuser/model_executor/layers/attention_processor_usp.py
@@ -0,0 +1,95 @@
from typing import Optional

import torch
import torch.distributed
from diffusers.models.attention import Attention
from .attention_processor import Attention

from diffusers.models.embeddings import apply_rotary_emb

from xfuser.model_executor.layers.usp import USP


class xFuserFluxAttnProcessor2_0USP:
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = USP(
            query, key, value, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
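
The split at encoder_hidden_states.shape[1] above relies on Flux's joint-sequence convention: the text (encoder) projections are prepended to the image projections along the sequence axis before attention and peeled off again afterwards. A minimal shape-only sketch of that convention (the head count and sequence lengths below are illustrative assumptions, not values from this diff):

import torch

b, heads, d_head = 1, 24, 128
txt_len, img_len = 512, 4096                                   # illustrative sizes
txt_q = torch.randn(b, heads, txt_len, d_head)
img_q = torch.randn(b, heads, img_len, d_head)

# text projections are prepended along the sequence axis (dim=2), as in the processor
joint = torch.cat([txt_q, img_q], dim=2)                       # [b, heads, txt_len + img_len, d_head]

# stand-in for the attention output, reshaped back to [b, seq, heads * d_head]
out = joint.transpose(1, 2).reshape(b, -1, heads * d_head)

# split back at the text length: encoder part first, image part second
txt_out, img_out = out[:, :txt_len], out[:, txt_len:]
assert txt_out.shape == (b, txt_len, heads * d_head)
assert img_out.shape == (b, img_len, heads * d_head)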
98 changes: 98 additions & 0 deletions xfuser/model_executor/layers/usp.py
@@ -0,0 +1,98 @@
import torch
from torch.nn import functional as F
from torch.distributed.tensor.experimental._attention import _templated_ring_attention
aten = torch.ops.aten

import torch.distributed._functional_collectives as ft_c
import torch.distributed as dist

from yunchang.globals import PROCESS_GROUP
from yunchang.comm.all_to_all import SeqAllToAll4D
from xfuser.core.distributed import (
    get_sequence_parallel_world_size,
    get_ulysses_parallel_world_size,
    get_ring_parallel_world_size,
)

def ring_attn(query, key, value, dropout_p=0.0, is_causal=False):
    out, *_ = _templated_ring_attention(
        PROCESS_GROUP.RING_PG,
        aten._scaled_dot_product_flash_attention,
        query,
        key,
        value,
        dropout_p=dropout_p,
        is_causal=is_causal
    )
    return out

def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
    """
    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
    so we cannot call ``wait()``.
    """
    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
        return tensor.wait()
    return tensor


def _sdpa_all_to_all_single(x):
    x_shape = x.shape
    x = x.flatten()
    x = ft_c.all_to_all_single(x, output_split_sizes=None, input_split_sizes=None, group=PROCESS_GROUP.ULYSSES_PG)
    x = _maybe_wait(x)
    x = x.reshape(x_shape)
    return x


def _ft_c_input_all_to_all(x):
    world_size = get_ulysses_parallel_world_size()
    if world_size <= 1:
        return x

    assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim)
    b, h, s, d = x.shape
    assert h % world_size == 0, "h must be divisible by world_size, got {} and {}".format(h, world_size)

    x = x.permute(1, 0, 2, 3).contiguous()
    x = _sdpa_all_to_all_single(x)
    x = x.reshape(world_size, h // world_size, b, -1, d).permute(2, 1, 0, 3, 4).reshape(b, h // world_size, -1, d)
    return x


def _ft_c_output_all_to_all(x):
    world_size = get_ulysses_parallel_world_size()
    if world_size <= 1:
        return x

    assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim)
    b, h, s, d = x.shape
    assert s % world_size == 0, "s must be divisible by world_size, got {} and {}".format(s, world_size)

    x = x.permute(2, 0, 1, 3).contiguous()
    x = _sdpa_all_to_all_single(x)
    x = x.reshape(world_size, s // world_size, b, -1, d).permute(2, 0, 3, 1, 4).reshape(b, -1, s // world_size, d)
    return x


def USP(query, key, value, dropout_p=0.0, is_causal=False):
    if get_sequence_parallel_world_size() == 1:
        out = F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=is_causal
        )
    elif get_ulysses_parallel_world_size() == 1:
        out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
    elif get_ulysses_parallel_world_size() > 1:
        query = _ft_c_input_all_to_all(query)
        key = _ft_c_input_all_to_all(key)
        value = _ft_c_input_all_to_all(value)

        if get_ring_parallel_world_size() == 1:
            out = F.scaled_dot_product_attention(
                query, key, value, dropout_p=dropout_p, is_causal=is_causal
            )
        else:
            out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)

        out = _ft_c_output_all_to_all(out)

    return out
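
For intuition about the Ulysses branch above: _ft_c_input_all_to_all trades a sequence shard of every head for the full sequence of a head shard ([b, h, s/N, d] -> [b, h/N, s, d]) and _ft_c_output_all_to_all reverses the trade. A minimal single-process sketch that simulates that exchange with plain tensor ops instead of all_to_all_single (an illustration of the data movement, not the actual collective):

import torch

N, b, h, s, d = 2, 1, 8, 16, 4
full = torch.randn(b, h, s, d)

# before the exchange each "rank" holds a sequence shard of all heads: [b, h, s/N, d]
seq_shards = list(full.chunk(N, dim=2))

# simulated all-to-all: rank j collects head-group j from every rank and
# concatenates along the sequence axis, ending up with [b, h/N, s, d]
head_shards = []
for j in range(N):
    recv = [shard.chunk(N, dim=1)[j] for shard in seq_shards]
    head_shards.append(torch.cat(recv, dim=2))

assert head_shards[0].shape == (b, h // N, s, d)
# the output all-to-all is the same trade in reverse: stitching the head shards
# back together along the head axis reproduces the full tensor
assert torch.equal(torch.cat(head_shards, dim=1), full)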
