Skip to content

Commit

Permalink
[Utils] Correct custom init sort (huggingface#4967)
Browse files Browse the repository at this point in the history
* [Utils] Correct custom init sort

* [Utils] Correct custom init sort

* [Utils] Correct custom init sort

* add type checking

* fix custom init sort

* fix test

* fix tests

---------

Co-authored-by: Dhruv Nair <[email protected]>
  • Loading branch information
patrickvonplaten and DN6 authored Sep 12, 2023
1 parent d82157b commit 18b7264
Show file tree
Hide file tree
Showing 39 changed files with 1,371 additions and 498 deletions.
31 changes: 29 additions & 2 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ..utils import _LazyModule, is_flax_available, is_torch_available


Expand Down Expand Up @@ -40,7 +42,32 @@
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]

import sys

if TYPE_CHECKING:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_tiny import AutoencoderTiny
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
from .modeling_utils import ModelMixin
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .vq_model import VQModel

if is_flax_available():
from .controlnet_flax import FlaxControlNetModel
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Loading

0 comments on commit 18b7264

Please sign in to comment.