Commit
add accelerate to load models with smaller memory footprint (huggingface#361)

* add accelerate to load models with smaller memory footprint

* remove low_cpu_mem_usage as it is redundant

* move accelerate init weights context to modeling utils

* add test to ensure results are the same when loading with accelerate

* add tests to ensure ram usage gets lower when using accelerate

* move accelerate logic to a single snippet under modeling utils and remove it from configuration utils

* format code to pass quality check

* fix imports with isort

* add accelerate to test extra deps

* only import accelerate if device_map is set to auto

* move accelerate availability check to diffusers import utils

* format code

Co-authored-by: Patrick von Platen <[email protected]>
piEsposito and patrickvonplaten authored Oct 4, 2022
1 parent 09859a3 commit 4d1cce2
Showing 5 changed files with 152 additions and 32 deletions.
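
For context, the new user-facing API looks like this: a minimal usage sketch based on the tests added below (the `fusing/unet-ldm-dummy-update` checkpoint and the input shapes are taken from those tests; this snippet is not part of the diff):

import torch
from diffusers import UNet2DModel

# device_map="auto" is the new argument: weights are materialized via
# accelerate with a much smaller peak CPU-memory footprint during loading.
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", device_map="auto")
model.to("cuda")
model.eval()

noise = torch.randn(
    1, model.config.in_channels, model.config.sample_size, model.config.sample_size
).to("cuda")
timestep = torch.tensor([10]).to("cuda")

# inference behaves exactly as with a normally loaded model
with torch.no_grad():
    sample = model(noise, timestep).sample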
1 change: 1 addition & 0 deletions setup.py
@@ -104,6 +104,7 @@
    "torch>=1.4",
    "torchvision",
    "transformers>=4.21.0",
    "accelerate>=0.12.0"
]

# this is a lookup table with items like:
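The truncated comment above introduces the `deps` lookup table that setup.py builds from this `_deps` list. A sketch of that pattern (the exact regex may differ slightly at this commit):

import re

# builds {"accelerate": "accelerate>=0.12.0", "torch": "torch>=1.4", ...} so
# that extras and dependency checks can look up a version pin by package name
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}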
101 changes: 69 additions & 32 deletions src/diffusers/modeling_utils.py
@@ -21,6 +21,7 @@
import torch
from torch import Tensor, device

from diffusers.utils import is_accelerate_available
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
@@ -293,33 +294,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        from_auto_class = kwargs.pop("_from_auto", False)
        torch_dtype = kwargs.pop("torch_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)

        user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path
        model, unused_kwargs = cls.from_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            **kwargs,
        )

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )
        elif torch_dtype is not None:
            model = model.to(torch_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
        # Load model
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -391,25 +372,81 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                )

        # restore default dtype
        state_dict = load_state_dict(model_file)
        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
            model,
            state_dict,
            model_file,
            pretrained_model_name_or_path,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
        )

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
        if device_map == "auto":
            if is_accelerate_available():
                import accelerate
            else:
                raise ImportError("Please install accelerate via `pip install accelerate`")

            with accelerate.init_empty_weights():
                model, unused_kwargs = cls.from_config(
                    config_path,
                    cache_dir=cache_dir,
                    return_unused_kwargs=True,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    device_map=device_map,
                    **kwargs,
                )

            accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)

            loading_info = {
                "missing_keys": [],
                "unexpected_keys": [],
                "mismatched_keys": [],
                "error_msgs": [],
            }
        else:
            model, unused_kwargs = cls.from_config(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                device_map=device_map,
                **kwargs,
            )

            state_dict = load_state_dict(model_file)
            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                model,
                state_dict,
                model_file,
                pretrained_model_name_or_path,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
            )

            if output_loading_info:
                loading_info = {
                    "missing_keys": missing_keys,
                    "unexpected_keys": unexpected_keys,
                    "mismatched_keys": mismatched_keys,
                    "error_msgs": error_msgs,
                }

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )
        elif torch_dtype is not None:
            model = model.to(torch_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
        if output_loading_info:
            return model, loading_info

        return model
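The heart of the `device_map == "auto"` branch above is accelerate's meta-device workflow: build the module structure without allocating weight storage, then stream the checkpoint directly onto the target devices. A self-contained sketch of that pattern (the `TinyNet` module and checkpoint path are hypothetical, for illustration only):

import accelerate
import torch
from torch import nn


class TinyNet(nn.Module):
    # hypothetical stand-in for a real diffusers model
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.proj(x)


# write a checkpoint to disk so the sketch is runnable end to end
torch.save(TinyNet().state_dict(), "tiny_net.bin")

# 1. Inside init_empty_weights(), parameters get shapes and dtypes but no
#    backing storage, so constructing the model costs almost no RAM.
with accelerate.init_empty_weights():
    model = TinyNet()

# 2. load_checkpoint_and_dispatch reads the weights from disk straight onto
#    the devices chosen by the device map; a second full copy of the model
#    never sits in CPU memory.
model = accelerate.load_checkpoint_and_dispatch(model, "tiny_net.bin", device_map="auto")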
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -23,6 +23,7 @@
    USE_TF,
    USE_TORCH,
    DummyObject,
    is_accelerate_available,
    is_flax_available,
    is_inflect_available,
    is_modelcards_available,
11 changes: 11 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -159,6 +159,13 @@
except importlib_metadata.PackageNotFoundError:
    _scipy_available = False

_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
    _accelerate_version = importlib_metadata.version("accelerate")
    logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
except importlib_metadata.PackageNotFoundError:
    _accelerate_available = False


def is_torch_available():
    return _torch_available
@@ -196,6 +203,10 @@ def is_scipy_available():
    return _scipy_available


def is_accelerate_available():
    return _accelerate_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
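The new helper is meant to guard the optional import, exactly as the `device_map == "auto"` branch in modeling_utils.py above uses it:

from diffusers.utils import is_accelerate_available

if is_accelerate_available():
    import accelerate
else:
    raise ImportError("Please install accelerate via `pip install accelerate`")

Checking the distribution metadata with `importlib_metadata.version` in addition to `find_spec` also catches broken installs where the module directory exists but no package metadata does.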
70 changes: 70 additions & 0 deletions tests/test_models_unet.py
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import math
import tracemalloc
import unittest

import torch
@@ -133,6 +135,74 @@ def test_from_pretrained_hub(self):

        assert image is not None, "Make sure output is not None"

    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
    def test_from_pretrained_accelerate(self):
        model, _ = UNet2DModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
        )
        model.to(torch_device)
        image = model(**self.dummy_input).sample

        assert image is not None, "Make sure output is not None"

    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
    def test_from_pretrained_accelerate_wont_change_results(self):
        model_accelerate, _ = UNet2DModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
        )
        model_accelerate.to(torch_device)
        model_accelerate.eval()

        noise = torch.randn(
            1,
            model_accelerate.config.in_channels,
            model_accelerate.config.sample_size,
            model_accelerate.config.sample_size,
            generator=torch.manual_seed(0),
        )
        noise = noise.to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)

        arr_accelerate = model_accelerate(noise, time_step)["sample"]

        # two models don't need to stay in the device at the same time
        del model_accelerate
        torch.cuda.empty_cache()
        gc.collect()

        model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model_normal_load.to(torch_device)
        model_normal_load.eval()
        arr_normal_load = model_normal_load(noise, time_step)["sample"]

        assert torch.allclose(arr_accelerate, arr_normal_load, rtol=1e-3)

    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
    def test_memory_footprint_gets_reduced(self):
        torch.cuda.empty_cache()
        gc.collect()

        tracemalloc.start()
        model_accelerate, _ = UNet2DModel.from_pretrained(
            "fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
        )
        model_accelerate.to(torch_device)
        model_accelerate.eval()
        _, peak_accelerate = tracemalloc.get_traced_memory()

        del model_accelerate
        torch.cuda.empty_cache()
        gc.collect()

        model_normal_load, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
        model_normal_load.to(torch_device)
        model_normal_load.eval()
        _, peak_normal = tracemalloc.get_traced_memory()

        tracemalloc.stop()

        assert peak_accelerate < peak_normal

    def test_output_pretrained(self):
        model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
        model.eval()
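One subtlety in `test_memory_footprint_gets_reduced`: `tracemalloc` only traces allocations made through Python's memory allocator (allocations PyTorch makes at the C++ level may not be fully captured), so the comparison is a proxy for peak host-RAM usage during loading, which is what the accelerate path reduces. Measuring GPU memory instead would use `torch.cuda` statistics; a sketch under that assumption, not part of the diff:

import torch
from diffusers import UNet2DModel

torch.cuda.reset_peak_memory_stats()
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", device_map="auto")
peak_gpu_bytes = torch.cuda.max_memory_allocated()  # peak CUDA allocation so far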
