From f4f854138dc68ec5ae8a30aac2462bc6b857266e Mon Sep 17 00:00:00 2001
From: ethansmith2000 <98723285+ethansmith2000@users.noreply.github.com>
Date: Mon, 7 Aug 2023 09:26:54 -0400
Subject: [PATCH] grad checkpointing (#4474)

* grad checkpointing

* fix make fix-copies

* fix

---------

Co-authored-by: Patrick von Platen
---
 src/diffusers/models/unet_2d_blocks.py        | 97 ++++++++-----------
 .../versatile_diffusion/modeling_text_unet.py | 45 ++++-----
 2 files changed, 59 insertions(+), 83 deletions(-)

diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 6f3037d624f9..e894628462ef 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -648,16 +648,13 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
@@ -1035,16 +1032,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1711,13 +1705,12 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -1912,15 +1905,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
                 )
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2173,16 +2164,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -2872,13 +2860,12 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    mask,
-                    cross_attention_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=mask,
+                    **cross_attention_kwargs,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
 
@@ -3094,16 +3081,14 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    attention_mask,
-                    cross_attention_kwargs,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
-                )[0]
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=temb,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
             else:
                 hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index adb41a8dfd07..fe9455a19bf0 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1429,16 +1429,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1668,16 +1665,13 @@ def custom_forward(*inputs):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1809,16 +1803,13 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
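
For reference, the pattern the hunks above converge on can be sketched outside of diffusers as follows. This is a minimal, self-contained sketch, not the library code: ToyResnet, ToyAttention, and the tensor shapes are invented for illustration, and it simply assumes that the attention module no longer needs to be wrapped in torch.utils.checkpoint.checkpoint (it is called directly with keyword arguments and return_dict=False), while the resnet keeps the explicit non-reentrant checkpoint available on torch >= 1.11.

# Hypothetical stand-ins, not diffusers classes.
import torch
import torch.nn as nn


class ToyResnet(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, temb):
        # Toy residual block: project and add the time embedding.
        return self.proj(hidden_states) + temb


class ToyAttention(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)

    def forward(self, hidden_states, encoder_hidden_states=None, return_dict=False):
        # Cross-attend to the encoder states when given, else self-attend.
        ctx = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        out, _ = self.attn(hidden_states, ctx, ctx)
        return (out,)  # tuple output, mirroring return_dict=False


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


resnet, attn = ToyResnet(), ToyAttention()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)
temb = torch.randn(2, 4, 8)
encoder_hidden_states = torch.randn(2, 6, 8)

# Resnet: still wrapped in an explicit (non-reentrant) gradient checkpoint.
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
# Attention: called directly with keyword arguments instead of being wrapped again.
hidden_states = attn(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    return_dict=False,
)[0]
hidden_states.sum().backward()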