forked from xdit-project/xDiT
Flux.1 performance optimizations on H100 (xdit-project#365)
Showing 4 changed files with 366 additions and 2 deletions.
@@ -0,0 +1,168 @@  (new file: Flux.1 example script driving USP + CFG parallelism)

import functools
from typing import List, Optional, Tuple, Union

import logging
import time
import torch
import torch.distributed
from diffusers import DiffusionPipeline, FluxPipeline

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_data_parallel_world_size,
    get_data_parallel_rank,
    get_runtime_state,
    get_classifier_free_guidance_world_size,
    get_classifier_free_guidance_rank,
    get_cfg_group,
    get_sequence_parallel_world_size,
    get_sequence_parallel_rank,
    get_sp_group,
    is_dp_last_group,
    initialize_runtime_state,
    get_pipeline_parallel_world_size,
)

from xfuser.model_executor.layers.attention_processor_usp import xFuserFluxAttnProcessor2_0USP


def parallelize_transformer(pipe: DiffusionPipeline):
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        *args,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        controlnet_block_samples: Optional[List[torch.Tensor]] = None,
        controlnet_single_block_samples: Optional[List[torch.Tensor]] = None,
        **kwargs,
    ):
        # split the batch across CFG ranks and the sequence across SP ranks
        if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
            timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()]
        encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
        txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]

        # route every attention layer through the USP attention processor
        for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
            block.attn.processor = xFuserFluxAttnProcessor2_0USP()

        output = original_forward(
            hidden_states,
            encoder_hidden_states,
            *args,
            timestep=timestep,
            img_ids=img_ids,
            txt_ids=txt_ids,
            **kwargs,
        )

        # gather local results back across SP (sequence dim) and CFG (batch dim) ranks
        return_dict = not isinstance(output, tuple)
        sample = output[0]
        sample = get_sp_group().all_gather(sample, dim=-2)
        sample = get_cfg_group().all_gather(sample, dim=0)
        if return_dict:
            return output.__class__(sample, *output[1:])
        return (sample, *output[1:])

    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward


def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16
    local_rank = get_world_group().local_rank

    pipe = FluxPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        torch_dtype=torch.bfloat16,
    )

    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    else:
        pipe = pipe.to(f"cuda:{local_rank}")

    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    initialize_runtime_state(pipe, engine_config)
    get_runtime_state().set_input_parameters(
        height=input_config.height,
        width=input_config.width,
        batch_size=1,
        num_inference_steps=input_config.num_inference_steps,
        max_condition_sequence_length=512,
        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
    )

    parallelize_transformer(pipe)

    if engine_config.runtime_config.use_torch_compile:
        # let the inductor scheduler overlap collectives with compute
        torch._inductor.config.reorder_for_compute_comm_overlap = True
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

    # warmup
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=1,
        output_type=input_config.output_type,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    ).images

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if input_config.output_type == "pil":
        dp_group_index = get_data_parallel_rank()
        num_dp_groups = get_data_parallel_world_size()
        dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
        if is_dp_last_group():
            for i, image in enumerate(output.images):
                image_rank = dp_group_index * dp_batch_size + i
                image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
                image.save(f"./results/{image_name}")
                print(f"image {i} saved to ./results/{image_name}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
        )
    get_runtime_state().destory_distributed_env()  # method name as spelled in xfuser


if __name__ == "__main__":
    main()
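The script parallelizes the transformer by rebinding a wrapped forward onto the module instance rather than subclassing it. The following standalone sketch of that same functools.wraps + __get__ pattern uses a hypothetical DummyTransformer that is not part of the commit; it only illustrates the mechanism used by parallelize_transformer above.

import functools
import torch
from torch import nn

class DummyTransformer(nn.Module):
    """Hypothetical stand-in for pipe.transformer, for illustration only."""
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states * 2

def patch_forward(module: nn.Module) -> None:
    original_forward = module.forward  # bound method captured before patching

    @functools.wraps(module.__class__.forward)
    def new_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # pre-process (the real example chunks inputs across CFG/SP ranks here)
        out = original_forward(hidden_states)
        # post-process (the real example all-gathers the output here)
        return out + 1

    # __get__ binds the function to this instance so it behaves like a method
    module.forward = new_forward.__get__(module)

m = DummyTransformer()
patch_forward(m)
print(m.forward(torch.ones(2, 4)))  # tensor of 3s: (1 * 2) + 1

Because only the instance is patched, the diffusers transformer class itself is left untouched and non-parallel runs keep their stock forward.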
@@ -0,0 +1,95 @@  (new file: xFuserFluxAttnProcessor2_0USP attention processor)

from typing import Optional

import torch
import torch.distributed
from diffusers.models.attention import Attention
from .attention_processor import Attention  # note: shadows the diffusers import above

from diffusers.models.embeddings import apply_rotary_emb

from xfuser.model_executor.layers.usp import USP


class xFuserFluxAttnProcessor2_0USP:
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: prepend the text-token projections along the sequence dim
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # sequence-parallel attention (Ulysses / ring) instead of plain SDPA
        hidden_states = USP(
            query, key, value, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # split the joint sequence back into text and image parts
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
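The processor closely mirrors the stock diffusers Flux attention processor and essentially only swaps the final scaled-dot-product-attention call for USP. A small shape sketch of the head split and text/image concatenation it performs, with dummy sizes that are not taken from the commit:

import torch

# dummy sizes for illustration only (not Flux's real dimensions)
batch, heads, head_dim = 2, 4, 8
img_tokens, txt_tokens = 16, 6

def to_heads(x: torch.Tensor) -> torch.Tensor:
    # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
    b, s, _ = x.shape
    return x.view(b, s, heads, head_dim).transpose(1, 2)

query = to_heads(torch.randn(batch, img_tokens, heads * head_dim))
txt_query = to_heads(torch.randn(batch, txt_tokens, heads * head_dim))

# text projections are prepended along the sequence dimension (dim=2), matching
# torch.cat([encoder_hidden_states_query_proj, query], dim=2) in the processor
query = torch.cat([txt_query, query], dim=2)
print(query.shape)  # torch.Size([2, 4, 22, 8])

# after attention the combined sequence is flattened back and split again
out = query.transpose(1, 2).reshape(batch, -1, heads * head_dim)
txt_out, img_out = out[:, :txt_tokens], out[:, txt_tokens:]
print(txt_out.shape, img_out.shape)  # [2, 6, 32] and [2, 16, 32]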
@@ -0,0 +1,98 @@  (new file: USP attention helpers: Ulysses all-to-all + ring attention dispatch)

import torch
from torch.nn import functional as F
from torch.distributed.tensor.experimental._attention import _templated_ring_attention

aten = torch.ops.aten

import torch.distributed._functional_collectives as ft_c
import torch.distributed as dist

from yunchang.globals import PROCESS_GROUP
from yunchang.comm.all_to_all import SeqAllToAll4D
from xfuser.core.distributed import (
    get_sequence_parallel_world_size,
    get_ulysses_parallel_world_size,
    get_ring_parallel_world_size,
)


def ring_attn(query, key, value, dropout_p=0.0, is_causal=False):
    # ring attention over the ring process group, built on PyTorch's
    # experimental templated ring attention and the flash SDPA kernel
    out, *_ = _templated_ring_attention(
        PROCESS_GROUP.RING_PG,
        aten._scaled_dot_product_flash_attention,
        query,
        key,
        value,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )
    return out


def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
    """
    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
    so we cannot call ``wait()``.
    """
    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
        return tensor.wait()
    return tensor


def _sdpa_all_to_all_single(x):
    x_shape = x.shape
    x = x.flatten()
    x = ft_c.all_to_all_single(x, output_split_sizes=None, input_split_sizes=None, group=PROCESS_GROUP.ULYSSES_PG)
    x = _maybe_wait(x)
    x = x.reshape(x_shape)
    return x


def _ft_c_input_all_to_all(x):
    # scatter heads / gather sequence across the Ulysses group:
    # [b, h, s, d] -> [b, h // world_size, s * world_size, d]
    world_size = get_ulysses_parallel_world_size()
    if world_size <= 1:
        return x

    assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim)
    b, h, s, d = x.shape
    assert h % world_size == 0, "h must be divisible by world_size, got {} and {}".format(h, world_size)

    x = x.permute(1, 0, 2, 3).contiguous()
    x = _sdpa_all_to_all_single(x)
    x = x.reshape(world_size, h // world_size, b, -1, d).permute(2, 1, 0, 3, 4).reshape(b, h // world_size, -1, d)
    return x


def _ft_c_output_all_to_all(x):
    # inverse exchange: gather heads / scatter sequence:
    # [b, h, s, d] -> [b, h * world_size, s // world_size, d]
    world_size = get_ulysses_parallel_world_size()
    if world_size <= 1:
        return x

    assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim)
    b, h, s, d = x.shape
    assert s % world_size == 0, "s must be divisible by world_size, got {} and {}".format(s, world_size)

    x = x.permute(2, 0, 1, 3).contiguous()
    x = _sdpa_all_to_all_single(x)
    x = x.reshape(world_size, s // world_size, b, -1, d).permute(2, 0, 3, 1, 4).reshape(b, -1, s // world_size, d)
    return x


def USP(query, key, value, dropout_p=0.0, is_causal=False):
    if get_sequence_parallel_world_size() == 1:
        # no sequence parallelism: plain scaled dot-product attention
        out = F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=is_causal
        )
    elif get_ulysses_parallel_world_size() == 1:
        # ring parallelism only
        out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)
    elif get_ulysses_parallel_world_size() > 1:
        # Ulysses (optionally combined with ring): exchange layout, attend, exchange back
        query = _ft_c_input_all_to_all(query)
        key = _ft_c_input_all_to_all(key)
        value = _ft_c_input_all_to_all(value)

        if get_ring_parallel_world_size() == 1:
            out = F.scaled_dot_product_attention(
                query, key, value, dropout_p=dropout_p, is_causal=is_causal
            )
        else:
            out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal)

        out = _ft_c_output_all_to_all(out)

    return out
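USP picks the attention path from the configured degrees: with no sequence parallelism it is plain SDPA, with Ulysses degree 1 it is ring attention only, and with Ulysses degree greater than 1 it first all-to-alls the inputs (scattering heads, gathering sequence), runs SDPA or ring attention, then all-to-alls the output back. Below is a single-process sketch that emulates the Ulysses input exchange with tensor chunks; the helper is hypothetical, uses no process groups, and is for illustration only.

from typing import List
import torch

def emulate_ulysses_input_all_to_all(per_rank: List[torch.Tensor]) -> List[torch.Tensor]:
    """Single-process emulation: each simulated rank starts with q/k/v of shape
    [b, h, s_local, d]; after the exchange it holds [b, h // world, s_local * world, d],
    i.e. fewer heads but the full sequence."""
    world = len(per_rank)
    # split every rank's tensor into `world` head groups
    head_groups = [list(x.chunk(world, dim=1)) for x in per_rank]
    out = []
    for dst in range(world):
        # simulated rank `dst` receives head group `dst` from every source rank
        # and concatenates the pieces along the sequence dimension
        out.append(torch.cat([head_groups[src][dst] for src in range(world)], dim=2))
    return out

b, h, s_local, d, world = 1, 8, 16, 4, 2
shards = [torch.randn(b, h, s_local, d) for _ in range(world)]
exchanged = emulate_ulysses_input_all_to_all(shards)
print(exchanged[0].shape)  # torch.Size([1, 4, 32, 4])

_ft_c_output_all_to_all applies the inverse exchange, so after attention each rank again holds all heads for its original sequence shard.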