1
- from typing import Dict , Union
1
+ from typing import Dict , Optional , Tuple , Union
2
2
3
3
import torch
4
4
import torch .nn as nn
@@ -13,24 +13,29 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
13
13
@register_to_config
14
14
def __init__ (
15
15
self ,
16
- sample_size = None ,
17
- in_channels = 4 ,
18
- out_channels = 4 ,
19
- center_input_sample = False ,
20
- flip_sin_to_cos = True ,
21
- freq_shift = 0 ,
22
- down_block_types = ("CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "DownBlock2D" ),
23
- up_block_types = ("UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ),
24
- block_out_channels = (320 , 640 , 1280 , 1280 ),
25
- layers_per_block = 2 ,
26
- downsample_padding = 1 ,
27
- mid_block_scale_factor = 1 ,
28
- act_fn = "silu" ,
29
- norm_num_groups = 32 ,
30
- norm_eps = 1e-5 ,
31
- cross_attention_dim = 1280 ,
32
- attention_head_dim = 8 ,
33
- ) -> None :
16
+ sample_size : Optional [int ] = None ,
17
+ in_channels : int = 4 ,
18
+ out_channels : int = 4 ,
19
+ center_input_sample : bool = False ,
20
+ flip_sin_to_cos : bool = True ,
21
+ freq_shift : int = 0 ,
22
+ down_block_types : Tuple [str ] = (
23
+ "CrossAttnDownBlock2D" ,
24
+ "CrossAttnDownBlock2D" ,
25
+ "CrossAttnDownBlock2D" ,
26
+ "DownBlock2D" ,
27
+ ),
28
+ up_block_types : Tuple [str ] = ("UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ),
29
+ block_out_channels : Tuple [int ] = (320 , 640 , 1280 , 1280 ),
30
+ layers_per_block : int = 2 ,
31
+ downsample_padding : int = 1 ,
32
+ mid_block_scale_factor : float = 1 ,
33
+ act_fn : str = "silu" ,
34
+ norm_num_groups : int = 32 ,
35
+ norm_eps : float = 1e-5 ,
36
+ cross_attention_dim : int = 1280 ,
37
+ attention_head_dim : int = 8 ,
38
+ ):
34
39
super ().__init__ ()
35
40
36
41
self .sample_size = sample_size
0 commit comments