Skip to content

Commit

Permalink
Fix PyCharm/VSCode static type checking for dummy objects (huggingfac…
Browse files Browse the repository at this point in the history
…e#1596)

* Fix PyCharm/VSCode static type checking for dummy objects

* Re-add dummies

* Fix AudioDiffusion imports

* fix import

* fix import

* Update utils/check_dummies.py

Co-authored-by: Pedro Cuenca <[email protected]>

* Update src/diffusers/utils/import_utils.py

* Update src/diffusers/__init__.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/stable_diffusion/__init__.py

* fix double import

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
3 people authored Dec 8, 2022
1 parent 03566d8 commit dbe0719
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 87 deletions.
74 changes: 52 additions & 22 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
from .utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_inflect_available,
is_k_diffusion_available,
is_librosa_available,
is_onnx_available,
is_scipy_available,
is_torch_available,
Expand All @@ -15,7 +17,12 @@
)


if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
Expand All @@ -29,14 +36,12 @@
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import (
AudioDiffusionPipeline,
DanceDiffusionPipeline,
DDIMPipeline,
DDPMPipeline,
KarrasVePipeline,
LDMPipeline,
LDMSuperResolutionPipeline,
Mel,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
Expand All @@ -60,15 +65,22 @@
VQDiffusionScheduler,
)
from .training_utils import EMAModel
else:
from .utils.dummy_pt_objects import * # noqa F403

if is_torch_available() and is_scipy_available():
from .schedulers import LMSDiscreteScheduler
else:
try:
if not (is_torch_available() and is_scipy_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
else:
from .schedulers import LMSDiscreteScheduler

if is_torch_available() and is_transformers_available():

try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
Expand All @@ -88,26 +100,43 @@
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
else:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403

if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
from .pipelines import StableDiffusionKDiffusionPipeline
else:
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .pipelines import StableDiffusionKDiffusionPipeline

if is_torch_available() and is_transformers_available() and is_onnx_available():
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
else:
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)

try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
from .pipelines import AudioDiffusionPipeline, Mel

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_objects import * # noqa F403
else:
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
Expand All @@ -122,10 +151,11 @@
FlaxSchedulerMixin,
FlaxScoreSdeVeScheduler,
)
else:
from .utils.dummy_flax_objects import * # noqa F403

if is_flax_available() and is_transformers_available():
from .pipelines import FlaxStableDiffusionPipeline
else:
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipelines import FlaxStableDiffusionPipeline
48 changes: 38 additions & 10 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ..utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
Expand All @@ -8,7 +9,12 @@
)


if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
Expand All @@ -18,15 +24,21 @@
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
else:
from ..utils.dummy_pt_objects import * # noqa F403

if is_torch_available() and is_librosa_available():
from .audio_diffusion import AudioDiffusionPipeline, Mel
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_librosa_objects import * # noqa F403
else:
from ..utils.dummy_torch_and_librosa_objects import AudioDiffusionPipeline, Mel # noqa F403
from .audio_diffusion import AudioDiffusionPipeline, Mel

if is_torch_available() and is_transformers_available():
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .paint_by_example import PaintByExamplePipeline
Expand All @@ -48,7 +60,12 @@
)
from .vq_diffusion import VQDiffusionPipeline

if is_transformers_available() and is_onnx_available():
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
else:
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
Expand All @@ -57,8 +74,19 @@
StableDiffusionOnnxPipeline,
)

if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .stable_diffusion import StableDiffusionKDiffusionPipeline

if is_transformers_available() and is_flax_available():

try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .stable_diffusion import FlaxStableDiffusionPipeline
17 changes: 13 additions & 4 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ...utils import (
BaseOutput,
OptionalDependencyNotAvailable,
is_flax_available,
is_k_diffusion_available,
is_onnx_available,
Expand Down Expand Up @@ -44,12 +45,20 @@ class StableDiffusionPipelineOutput(BaseOutput):
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .safety_checker import StableDiffusionSafetyChecker

if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
else:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
else:
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline

if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline

if is_transformers_available() and is_onnx_available():
Expand Down
24 changes: 16 additions & 8 deletions src/diffusers/pipelines/versatile_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
from ...utils import is_torch_available, is_transformers_available, is_transformers_version
from ...utils import (
OptionalDependencyNotAvailable,
is_torch_available,
is_transformers_available,
is_transformers_version,
)


if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
else:
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
else:
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
29 changes: 19 additions & 10 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
# limitations under the License.


from ..utils import is_flax_available, is_scipy_available, is_torch_available
from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available


if is_torch_available():
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
Expand All @@ -34,10 +39,13 @@
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
else:
from ..utils.dummy_pt_objects import * # noqa F403

if is_flax_available():
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_objects import * # noqa F403
else:
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
Expand All @@ -46,11 +54,12 @@
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
else:
from ..utils.dummy_flax_objects import * # noqa F403


if is_scipy_available() and is_torch_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
try:
if not (is_torch_available() and is_scipy_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
else:
from .scheduling_lms_discrete import LMSDiscreteScheduler
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
USE_TF,
USE_TORCH,
DummyObject,
OptionalDependencyNotAvailable,
is_accelerate_available,
is_flax_available,
is_inflect_available,
Expand Down
30 changes: 0 additions & 30 deletions src/diffusers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,21 +152,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class AudioDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DanceDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]

Expand Down Expand Up @@ -257,21 +242,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class Mel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
Loading

0 comments on commit dbe0719

Please sign in to comment.