From 86a26761acbc9f152ad006493ebe1345fc095d31 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 5 Jan 2024 18:02:09 +0530 Subject: [PATCH] Correctly handle creating model index json files when setting compiled modules in pipelines. (#6436) update --- src/diffusers/pipelines/pipeline_utils.py | 60 ++++++++++++----------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3054c491fd1f..de5dea679ee9 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -530,6 +530,36 @@ def load_sub_model( return loaded_sub_model +def _fetch_class_library_tuple(module): + # import it here to avoid circular import + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + # register the config from the original module, not the dynamo compiled one + not_compiled_module = _unwrap_model(module) + library = not_compiled_module.__module__.split(".")[0] + + # check if the module is a pipeline module + module_path_items = not_compiled_module.__module__.split(".") + pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None + + path = not_compiled_module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if is_pipeline_module: + library = pipeline_dir + elif library not in LOADABLE_CLASSES: + library = not_compiled_module.__module__ + + # retrieve class_name + class_name = not_compiled_module.__class__.__name__ + + return (library, class_name) + + class DiffusionPipeline(ConfigMixin, PushToHubMixin): r""" Base class for all pipelines. @@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): _is_onnx = False def register_modules(self, **kwargs): - # import it here to avoid circular import - diffusers_module = importlib.import_module(__name__.split(".")[0]) - pipelines = getattr(diffusers_module, "pipelines") - for name, module in kwargs.items(): # retrieve library if module is None or isinstance(module, (tuple, list)) and module[0] is None: register_dict = {name: (None, None)} else: - # register the config from the original module, not the dynamo compiled one - not_compiled_module = _unwrap_model(module) - - library = not_compiled_module.__module__.split(".")[0] - - # check if the module is a pipeline module - module_path_items = not_compiled_module.__module__.split(".") - pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None - - path = not_compiled_module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if is_pipeline_module: - library = pipeline_dir - elif library not in LOADABLE_CLASSES: - library = not_compiled_module.__module__ - - # retrieve class_name - class_name = not_compiled_module.__class__.__name__ - + library, class_name = _fetch_class_library_tuple(module) register_dict = {name: (library, class_name)} # save model index config @@ -601,7 +605,7 @@ def __setattr__(self, name: str, value: Any): # We need to overwrite the config if name exists in config if isinstance(getattr(self.config, name), (tuple, list)): if value is not None and self.config[name][0] is not None: - class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__) + class_library_tuple = _fetch_class_library_tuple(value) else: class_library_tuple = (None, None)