8000 remove void, add types for params · huggingface/diffusers@45df73e · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 45df73e

Browse files
committed
remove void, add types for params
1 parent a273710 commit 45df73e

File tree

2 files changed

+43
-38
lines changed

2 files changed

+43
-38
lines changed

src/diffusers/models/unet_2d.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Optional, Tuple, Union
22

33
import torch
44
import torch.nn as nn
@@ -13,24 +13,24 @@ class UNet2DModel(ModelMixin, ConfigMixin):
1313
@register_to_config
1414
def __init__(
1515
self,
16-
sample_size=None,
17-
in_channels=3,
18-
out_channels=3,
19-
center_input_sample=False,
20-
time_embedding_type="positional",
21-
freq_shift=0,
22-
flip_sin_to_cos=True,
23-
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
24-
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
25-
block_out_channels=(224, 448, 672, 896),
26-
layers_per_block=2,
27-
mid_block_scale_factor=1,
28-
downsample_padding=1,
29-
act_fn="silu",
30-
attention_head_dim=8,
31-
norm_num_groups=32,
32-
norm_eps=1e-5,
33-
) -> None:
16+
sample_size: Optional[int] = None,
17+
in_channels: int = 3,
18+
out_channels: int = 3,
19+
center_input_sample: bool = False,
20+
time_embedding_type: str = "positional",
21+
freq_shift: int = 0,
22+
flip_sin_to_cos: bool = True,
23+
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
24+
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
25+
block_out_channels: Tuple[int] = (224, 448, 672, 896),
26+
layers_per_block: int = 2,
27+
mid_block_scale_factor: float = 1,
28+
downsample_padding: int = 1,
29+
act_fn: str = "silu",
30+
attention_head_dim: int = 8,
31+
norm_num_groups: int = 32,
32+
norm_eps: float = 1e-5,
33+
):
3434
super().__init__()
3535

3636
self.sample_size = sample_size

src/diffusers/models/unet_2d_condition.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Optional, Tuple, Union
22

33
import torch
44
import torch.nn as nn
@@ -13,24 +13,29 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
1313
@register_to_config
1414
def __init__(
1515
self,
16-
sample_size=None,
17-
in_channels=4,
18-
out_channels=4,
19-
center_input_sample=False,
20-
flip_sin_to_cos=True,
21-
freq_shift=0,
22-
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
23-
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
24-
block_out_channels=(320, 640, 1280, 1280),
25-
layers_per_block=2,
26-
downsample_padding=1,
27-
mid_block_scale_factor=1,
28-
act_fn="silu",
29-
norm_num_groups=32,
30-
norm_eps=1e-5,
31-
cross_attention_dim=1280,
32-
attention_head_dim=8,
33-
) -> None:
16+
sample_size: Optional[int] = None,
17+
in_channels: int = 4,
18+
out_channels: int = 4,
19+
center_input_sample: bool = False,
20+
flip_sin_to_cos: bool = True,
21+
freq_shift: int = 0,
22+
down_block_types: Tuple[str] = (
23+
"CrossAttnDownBlock2D",
24+
"CrossAttnDownBlock2D",
25+
"CrossAttnDownBlock2D",
26+
"DownBlock2D",
27+
),
28+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
29+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
30+
layers_per_block: int = 2,
31+
downsample_padding: int = 1,
32+
mid_block_scale_factor: float = 1,
33+
act_fn: str = "silu",
34+
norm_num_groups: int = 32,
35+
norm_eps: float = 1e-5,
36+
cross_attention_dim: int = 1280,
37+
attention_head_dim: int = 8,
38+
):
3439
super().__init__()
3540

3641
self.sample_size = sample_size

0 commit comments

Comments
 (0)
0