grad checkpointing (huggingface#4474)
* grad checkpointing

* fix make fix-copies

* fix

---------

Co-authored-by: Patrick von Platen <[email protected]>
ethansmith2000 and patrickvonplaten authored Aug 7, 2023
1 parent e1b5b8b commit f4f8541
Showing 2 changed files with 59 additions and 83 deletions.
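
Every hunk below makes the same change: the attn module is no longer routed through torch.utils.checkpoint.checkpoint via create_custom_forward but is called directly with keyword arguments, while the neighbouring resnet calls keep their checkpoint wrappers. The short commit message does not state the rationale; a likely one is that the attention module already applies gradient checkpointing internally when it is enabled, making the outer wrapper redundant. For context, a minimal, self-contained sketch of what the removed wrapper does; the toy module and shapes are illustrative and not taken from the diff:

import torch
import torch.nn as nn
import torch.utils.checkpoint

block = nn.Sequential(nn.Linear(8, 8), nn.SiLU(), nn.Linear(8, 8))
x = torch.randn(4, 8, requires_grad=True)

# With checkpointing, activations inside the block are not stored during the
# forward pass; they are recomputed during backward, trading compute for memory.
y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()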
97 changes: 41 additions & 56 deletions src/diffusers/models/unet_2d_blocks.py
@@ -648,16 +648,13 @@ def custom_forward(*inputs):
return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     None,  # timestep
-     None,  # class_labels
-     cross_attention_kwargs,
-     attention_mask,
-     encoder_attention_mask,
-     **ckpt_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     cross_attention_kwargs=cross_attention_kwargs,
+     attention_mask=attention_mask,
+     encoder_attention_mask=encoder_attention_mask,
+     return_dict=False,
+ )[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
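
Both files repeatedly call create_custom_forward, of which the collapsed context above shows only the tail (return custom_forward). For reference, the helper reads roughly as follows in diffusers around this revision; this is a reconstruction added for readability, not part of the diff:

def create_custom_forward(module, return_dict=None):
    def custom_forward(*inputs):
        # Call the wrapped module; only pass return_dict through when it was
        # given, so the helper also works for modules without that argument.
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward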
@@ -1035,16 +1032,13 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     None,  # timestep
-     None,  # class_labels
-     cross_attention_kwargs,
-     attention_mask,
-     encoder_attention_mask,
-     **ckpt_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     cross_attention_kwargs=cross_attention_kwargs,
+     attention_mask=attention_mask,
+     encoder_attention_mask=encoder_attention_mask,
+     return_dict=False,
+ )[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -1711,13 +1705,12 @@ def custom_forward(*inputs):
return custom_forward

hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     mask,
-     cross_attention_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     attention_mask=mask,
+     **cross_attention_kwargs,
+ )
else:
hidden_states = resnet(hidden_states, temb)

@@ -1912,15 +1905,13 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     temb,
-     attention_mask,
-     cross_attention_kwargs,
-     encoder_attention_mask,
-     **ckpt_kwargs,
- )
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     emb=temb,
+     attention_mask=attention_mask,
+     cross_attention_kwargs=cross_attention_kwargs,
+     encoder_attention_mask=encoder_attention_mask,
+ )
else:
hidden_states = resnet(hidden_states, temb)
@@ -2173,16 +2164,13 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     None,  # timestep
-     None,  # class_labels
-     cross_attention_kwargs,
-     attention_mask,
-     encoder_attention_mask,
-     **ckpt_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     cross_attention_kwargs=cross_attention_kwargs,
+     attention_mask=attention_mask,
+     encoder_attention_mask=encoder_attention_mask,
+     return_dict=False,
+ )[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -2872,13 +2860,12 @@ def custom_forward(*inputs):
return custom_forward

hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     mask,
-     cross_attention_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     attention_mask=mask,
+     **cross_attention_kwargs,
+ )
else:
hidden_states = resnet(hidden_states, temb)

@@ -3094,16 +3081,14 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     temb,
-     attention_mask,
-     cross_attention_kwargs,
-     encoder_attention_mask,
-     **ckpt_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     emb=temb,
+     attention_mask=attention_mask,
+     cross_attention_kwargs=cross_attention_kwargs,
+     encoder_attention_mask=encoder_attention_mask,
+ )
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
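The resnet checkpoint calls that remain above still thread ckpt_kwargs through, which requests the non-reentrant checkpoint implementation only on PyTorch versions that support it (1.11 and newer). Outside diffusers the same guard can be written as below; is_torch_version is the diffusers helper, the rest is an illustrative equivalent:

from packaging import version

import torch

# use_reentrant was added to torch.utils.checkpoint.checkpoint in PyTorch 1.11
ckpt_kwargs = (
    {"use_reentrant": False}
    if version.parse(torch.__version__) >= version.parse("1.11.0")
    else {}
)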
45 changes: 18 additions & 27 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1429,16 +1429,13 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     None,  # timestep
-     None,  # class_labels
-     cross_attention_kwargs,
-     attention_mask,
-     encoder_attention_mask,
-     **ckpt_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     cross_attention_kwargs=cross_attention_kwargs,
+     attention_mask=attention_mask,
+     encoder_attention_mask=encoder_attention_mask,
+     return_dict=False,
+ )[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -1668,16 +1665,13 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     None,  # timestep
-     None,  # class_labels
-     cross_attention_kwargs,
-     attention_mask,
-     encoder_attention_mask,
-     **ckpt_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     cross_attention_kwargs=cross_attention_kwargs,
+     attention_mask=attention_mask,
+     encoder_attention_mask=encoder_attention_mask,
+     return_dict=False,
+ )[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -1809,16 +1803,13 @@ def custom_forward(*inputs):
return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
-     create_custom_forward(attn, return_dict=False),
-     hidden_states,
-     encoder_hidden_states,
-     None,  # timestep
-     None,  # class_labels
-     cross_attention_kwargs,
-     attention_mask,
-     encoder_attention_mask,
-     **ckpt_kwargs,
- )[0]
+ hidden_states = attn(
+     hidden_states,
+     encoder_hidden_states=encoder_hidden_states,
+     cross_attention_kwargs=cross_attention_kwargs,
+     attention_mask=attention_mask,
+     encoder_attention_mask=encoder_attention_mask,
+     return_dict=False,
+ )[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
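From the training side nothing changes in how checkpointing is switched on: it is enabled on the model as a whole, and in the blocks touched here the attention modules are now called directly while the resnet calls keep the explicit wrapper. A minimal usage sketch; the repository id is only an example:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # example repo id
)
unet.enable_gradient_checkpointing()
unet.train()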
