[Type Hints] DDIM pipelines (huggingface#345)
* type hints

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <[email protected]>
sidthekidder and anton-l authored Sep 5, 2022
1 parent cc59b05 commit ada09bd
Showing 3 changed files with 25 additions and 10 deletions.
12 changes: 6 additions & 6 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -15,7 +15,7 @@
 
 
 import warnings
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -31,11 +31,11 @@ def __init__(self, unet, scheduler):
     @torch.no_grad()
     def __call__(
         self,
-        batch_size=1,
-        generator=None,
-        eta=0.0,
-        num_inference_steps=50,
-        output_type="pil",
+        batch_size: int = 1,
+        generator: Optional[torch.Generator] = None,
+        eta: float = 0.0,
+        num_inference_steps: int = 50,
+        output_type: Optional[str] = "pil",
         return_dict: bool = True,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
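For orientation, a minimal sketch of driving the re-typed DDIMPipeline.__call__ above. It is not part of the commit: the small, randomly initialized UNet2DModel configuration below is an assumption used only to keep the snippet self-contained; a real run would load a pretrained checkpoint instead.

import torch
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel

# Tiny, untrained UNet so the pipeline can be instantiated without downloading weights.
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64, 64, 64),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
)
pipe = DDIMPipeline(unet=unet, scheduler=DDIMScheduler())

generator = torch.Generator().manual_seed(0)  # matches generator: Optional[torch.Generator]

# batch_size: int, eta: float, num_inference_steps: int, output_type: Optional[str]
result = pipe(batch_size=2, generator=generator, eta=0.0, num_inference_steps=10, output_type="pil")
# result is an ImagePipelineOutput, since return_dict defaults to True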
9 changes: 7 additions & 2 deletions src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -15,7 +15,7 @@
 
 
 import warnings
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -30,7 +30,12 @@ def __init__(self, unet, scheduler):
 
     @torch.no_grad()
     def __call__(
-        self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs
+        self,
+        batch_size: int = 1,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
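Likewise, a hedged usage sketch for the re-typed DDPMPipeline.__call__, using return_dict=False to exercise the plain-tuple branch of the Union[ImagePipelineOutput, Tuple] return annotation. The checkpoint id "google/ddpm-cifar10-32" is an assumption; any DDPM checkpoint works.

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
generator = torch.Generator().manual_seed(0)  # generator: Optional[torch.Generator]

# With return_dict=False the pipeline returns a one-element tuple instead of an ImagePipelineOutput.
(images,) = pipe(batch_size=1, generator=generator, output_type="pil", return_dict=False)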
14 changes: 12 additions & 2 deletions src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -10,13 +10,23 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import logging
 
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 
 
 class LDMTextToImagePipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+    def __init__(
+        self,
+        vqvae: Union[VQModel, AutoencoderKL],
+        bert: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer,
+        unet: Union[UNet2DModel, UNet2DConditionModel],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+    ):
         super().__init__()
         scheduler = scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@@ -618,7 +628,7 @@ def custom_forward(*inputs):
 
 
 class LDMBertModel(LDMBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: LDMBertConfig):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
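Finally, a short sketch (not in the diff) of what the new LDMTextToImagePipeline.__init__ annotations describe at load time. The checkpoint id "CompVis/ldm-text2im-large-256" and the prompt are assumptions.

from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

# The registered modules line up with the Union annotations added above:
# pipe.vqvae     -> VQModel or AutoencoderKL
# pipe.bert      -> a transformers PreTrainedModel (an LDMBertModel in this checkpoint)
# pipe.tokenizer -> a transformers PreTrainedTokenizer
# pipe.unet      -> UNet2DModel or UNet2DConditionModel
# pipe.scheduler -> DDIMScheduler, PNDMScheduler, or LMSDiscreteScheduler

output = pipe("a painting of a squirrel eating a banana", num_inference_steps=50)  # ImagePipelineOutput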
