[Tests] clean up and refactor gradient checkpointing tests (huggingface#9494)

* check.

* fixes

* fixes

* updates

* fixes

* fixes
sayakpaul authored Oct 31, 2024
1 parent 8ce37ab commit 4adf6af
Showing 15 changed files with 180 additions and 273 deletions.
109 changes: 25 additions & 84 deletions tests/models/autoencoders/test_models_vae.py
@@ -39,7 +39,6 @@
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_accelerator_with_training,
require_torch_gpu,
skip_mps,
slow,
@@ -170,52 +169,17 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_forward_signature(self):
pass

@unittest.skip("Not tested.")
def test_training(self):
pass

@require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)

assert not model.is_gradient_checkpointing and model.training

out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()

labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()

# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()

assert model_2.is_gradient_checkpointing and model_2.training

out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()

# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
@@ -329,9 +293,11 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_forward_signature(self):
pass

@unittest.skip("Not tested.")
def test_forward_with_norm_groups(self):
pass

@@ -364,9 +330,20 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_outputs_equivalence(self):
pass

def test_gradient_checkpointing_is_applied(self):
expected_set = {"DecoderTiny", "EncoderTiny"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip(
"Gradient checkpointing is supported but this test doesn't apply to this class because it's forward is a bit different from the rest."
)
def test_effective_gradient_checkpointing(self):
pass


class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
@@ -443,55 +420,17 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_forward_signature(self):
pass

@unittest.skip("Not tested.")
def test_training(self):
pass

@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)

assert not model.is_gradient_checkpointing and model.training

out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()

labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()

# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()

assert model_2.is_gradient_checkpointing and model_2.training

out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()

# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue

self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
@@ -522,9 +461,11 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("Not tested.")
def test_forward_signature(self):
pass

@unittest.skip("Not tested.")
def test_forward_with_norm_groups(self):
pass

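The VAE suites above now lean on the shared mixin instead of hand-rolled gradient-checkpointing tests. For reference, the toggle those tests exercise is the same public API users call directly; a minimal sketch, assuming the `fusing/autoencoder-kl-dummy` checkpoint already referenced in this file and an assumed 32x32 input size:

```python
# Minimal sketch of the user-facing API these tests exercise; the checkpoint name comes
# from this test file, while the 32x32 input size is an assumption about the dummy config.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
vae.train()  # checkpointing only recomputes activations when gradients are tracked

vae.enable_gradient_checkpointing()
assert vae.is_gradient_checkpointing

sample = torch.randn(1, 3, 32, 32)
out = vae(sample).sample
(out - torch.randn_like(out)).mean().backward()  # Encoder/Decoder activations recomputed here

vae.disable_gradient_checkpointing()
assert not vae.is_gradient_checkpointing
```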
97 changes: 97 additions & 0 deletions tests/models/test_modeling_common.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import json
import os
@@ -57,6 +58,7 @@
require_torch_gpu,
require_torch_multi_gpu,
run_test_in_subprocess,
torch_all_close,
torch_device,
)

@@ -785,6 +787,101 @@ def test_enable_disable_gradient_checkpointing(self):
model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing)

@require_torch_accelerator_with_training
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing

# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)

assert not model.is_gradient_checkpointing and model.training

out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()

labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()

# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()

assert model_2.is_gradient_checkpointing and model_2.training

out_2 = model_2(**inputs_dict_copy).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()

# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < loss_tolerance)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())

for name, param in named_params.items():
if "post_quant_conv" in name:
continue
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))

@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
def test_gradient_checkpointing_is_applied(
self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
if self.model_class.__name__ in [
"UNetSpatioTemporalConditionModel",
"AutoencoderKLTemporalDecoder",
]:
return

init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

if attention_head_dim is not None:
init_dict["attention_head_dim"] = attention_head_dim
if num_attention_heads is not None:
init_dict["num_attention_heads"] = num_attention_heads
if block_out_channels is not None:
init_dict["block_out_channels"] = block_out_channels

model_class_copy = copy.copy(self.model_class)

modules_with_gc_enabled = {}

# now monkey patch the following function:
# def _set_gradient_checkpointing(self, module, value=False):
# if hasattr(module, "gradient_checkpointing"):
# module.gradient_checkpointing = value

def _set_gradient_checkpointing_new(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
modules_with_gc_enabled[module.__class__.__name__] = True

model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new

model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()

print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")

assert set(modules_with_gc_enabled.keys()) == expected_set
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
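The two new common tests do the heavy lifting: `test_effective_gradient_checkpointing` checks that recomputing activations leaves losses and parameter gradients unchanged, and `test_gradient_checkpointing_is_applied` checks which module classes actually receive the flag. Below is a standalone sketch of the first property, using plain `torch.utils.checkpoint` on a toy module rather than a diffusers model (all class and variable names are illustrative, not part of this commit):

```python
# Standalone illustration of what `test_effective_gradient_checkpointing` asserts:
# activation recomputation changes memory use, not results, so losses and parameter
# gradients must match the plain forward/backward pass within a small tolerance.
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    def __init__(self, dim=16, use_checkpoint=False):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        if self.use_checkpoint and self.training:
            # drop intermediate activations now, recompute them during backward
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)


torch.manual_seed(0)
plain = TinyBlock(use_checkpoint=False)
ckpt = copy.deepcopy(plain)  # identical weights, checkpointing enabled
ckpt.use_checkpoint = True

x = torch.randn(4, 16)
target = torch.randn(4, 16)

loss_plain = (plain(x) - target).mean()
loss_plain.backward()

loss_ckpt = (ckpt(x) - target).mean()
loss_ckpt.backward()

assert torch.allclose(loss_plain, loss_ckpt, atol=1e-5)
for (name, p), (_, q) in zip(plain.named_parameters(), ckpt.named_parameters()):
    assert torch.allclose(p.grad, q.grad, atol=5e-5), name
```

Comparing every parameter gradient rather than just the loss is the stronger check, which is why the common test iterates over `named_parameters()` with its own tolerance.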
7 changes: 7 additions & 0 deletions tests/models/transformers/test_models_dit_transformer2d.py
@@ -84,6 +84,13 @@ def test_correct_class_remapping_from_dict_config(self):
model = Transformer2DModel.from_config(init_dict)
assert isinstance(model, DiTTransformer2DModel)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"DiTTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_effective_gradient_checkpointing(self):
super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)

def test_correct_class_remapping_from_pretrained_config(self):
config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
model = Transformer2DModel.from_config(config)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_pixart_transformer2d.py
@@ -92,6 +92,10 @@ def test_output(self):
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"PixArtTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_correct_class_remapping_from_dict_config(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = Transformer2DModel.from_config(init_dict)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_transformer_allegro.py
@@ -77,3 +77,7 @@ def prepare_init_args_and_inputs_for_common(self):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {"AllegroTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@@ -74,6 +74,10 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {"AuraFlowTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
def test_set_attn_processor_for_determinism(self):
pass
@@ -81,3 +81,7 @@ def prepare_init_args_and_inputs_for_common(self):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@@ -83,3 +83,7 @@ def prepare_init_args_and_inputs_for_common(self):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView3PlusTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -111,3 +111,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_transformer_latte.py
@@ -86,3 +86,7 @@ def test_output(self):
super().test_output(
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
)

def test_gradient_checkpointing_is_applied(self):
expected_set = {"LatteTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
8 changes: 8 additions & 0 deletions tests/models/transformers/test_models_transformer_sd3.py
@@ -84,6 +84,10 @@ def prepare_init_args_and_inputs_for_common(self):
def test_set_attn_processor_for_determinism(self):
pass

def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
@@ -139,3 +143,7 @@ def prepare_init_args_and_inputs_for_common(self):
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass

def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
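
The `expected_set` values added throughout the transformer test files feed the recorder that the common test installs by monkey-patching `_set_gradient_checkpointing`. A standalone sketch of that recording trick on a toy model (the `ToyModel`, `Encoder`, and `Decoder` classes below are illustrative, not part of this commit):

```python
# Illustrative sketch of the recording trick behind `test_gradient_checkpointing_is_applied`:
# patch the hook that flips `gradient_checkpointing` so it also logs which module classes
# were touched, then compare that set against an `expected_set`.
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder, self.decoder = Encoder(), Decoder()

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # mimics how the hook is applied over all submodules
        self.apply(lambda m: self._set_gradient_checkpointing(m, value=True))


modules_with_gc_enabled = {}
original_hook = ToyModel._set_gradient_checkpointing


def recording_hook(self, module, value=False):
    if hasattr(module, "gradient_checkpointing"):
        module.gradient_checkpointing = value
        modules_with_gc_enabled[module.__class__.__name__] = True


ToyModel._set_gradient_checkpointing = recording_hook
model = ToyModel()
model.enable_gradient_checkpointing()
ToyModel._set_gradient_checkpointing = original_hook  # restore the original hook

expected_set = {"Encoder", "Decoder"}
assert set(modules_with_gc_enabled) == expected_set
assert model.encoder.gradient_checkpointing and model.decoder.gradient_checkpointing
```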