diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index db4c33690c9d..a04b9034b9de 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -13,23 +13,23 @@ class UNet2DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - sample_size=None, - in_channels=3, - out_channels=3, - center_input_sample=False, - time_embedding_type="positional", - freq_shift=0, - flip_sin_to_cos=True, - down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), - up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), - block_out_channels=(224, 448, 672, 896), - layers_per_block=2, - mid_block_scale_factor=1, - downsample_padding=1, - act_fn="silu", - attention_head_dim=8, - norm_num_groups=32, - norm_eps=1e-5, + sample_size: Optional[int] = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels: Tuple[int] = (224, 448, 672, 896), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + act_fn: str = "silu", + attention_head_dim: int = 8, + norm_num_groups: int = 32, + norm_eps: float = 1e-5, ): super().__init__() diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 25c4e37d8a6d..d4cab4dd905a 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -13,23 +13,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - sample_size=None, - in_channels=4, - out_channels=4, - center_input_sample=False, - flip_sin_to_cos=True, - freq_shift=0, - down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), - up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels=(320, 640, 1280, 1280), - layers_per_block=2, - downsample_padding=1, - mid_block_scale_factor=1, - act_fn="silu", - norm_num_groups=32, - norm_eps=1e-5, - cross_attention_dim=1280, - attention_head_dim=8, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: int = 8, ): super().__init__()