Skip to content

Commit

Permalink
Correctly handle creating model index json files when setting compile…
Browse files Browse the repository at this point in the history
…d modules in pipelines. (huggingface#6436)

update
  • Loading branch information
DN6 authored Jan 5, 2024
1 parent 6ef2b8a commit 86a2676
Showing 1 changed file with 32 additions and 28 deletions.
60 changes: 32 additions & 28 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,36 @@ def load_sub_model(
return loaded_sub_model


def _fetch_class_library_tuple(module):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

# register the config from the original module, not the dynamo compiled one
not_compiled_module = _unwrap_model(module)
library = not_compiled_module.__module__.split(".")[0]

# check if the module is a pipeline module
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None

path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = not_compiled_module.__module__

# retrieve class_name
class_name = not_compiled_module.__class__.__name__

return (library, class_name)


class DiffusionPipeline(ConfigMixin, PushToHubMixin):
r"""
Base class for all pipelines.
Expand All @@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
_is_onnx = False

def register_modules(self, **kwargs):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

for name, module in kwargs.items():
# retrieve library
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
register_dict = {name: (None, None)}
else:
# register the config from the original module, not the dynamo compiled one
not_compiled_module = _unwrap_model(module)

library = not_compiled_module.__module__.split(".")[0]

# check if the module is a pipeline module
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None

path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = not_compiled_module.__module__

# retrieve class_name
class_name = not_compiled_module.__class__.__name__

library, class_name = _fetch_class_library_tuple(module)
register_dict = {name: (library, class_name)}

# save model index config
Expand All @@ -601,7 +605,7 @@ def __setattr__(self, name: str, value: Any):
# We need to overwrite the config if name exists in config
if isinstance(getattr(self.config, name), (tuple, list)):
if value is not None and self.config[name][0] is not None:
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
class_library_tuple = _fetch_class_library_tuple(value)
else:
class_library_tuple = (None, None)

Expand Down

0 comments on commit 86a2676

Please sign in to comment.