[Type hint] scheduling ddim (#343) · huggingface/diffusers@5095a1d · GitHub

Commit 5095a1d

santiviquez authored and patrickvonplaten committed
[Type hint] scheduling ddim (#343)
* [Type hint] scheduling ddim

* apply suggestions from code review

apply suggestions to also return the return type

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 8172831 commit 5095a1d

File tree

1 file changed: +17 −12 lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 17 additions & 12 deletions
@@ -16,7 +16,7 @@
 # and https://github.com/hojonathanho/diffusion

 import math
-from typing import Union
+from typing import Optional, Union

 import numpy as np
 import torch
@@ -73,15 +73,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        num_train_timesteps=1000,
-        beta_start=0.0001,
-        beta_end=0.02,
-        beta_schedule="linear",
-        trained_betas=None,
-        timestep_values=None,
-        clip_sample=True,
-        set_alpha_to_one=True,
-        tensor_format="pt",
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        timestep_values: Optional[np.ndarray] = None,
+        clip_sample: bool = True,
+        set_alpha_to_one: bool = True,
+        tensor_format: str = "pt",
     ):
         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
@@ -122,7 +122,7 @@ def _get_variance(self, timestep, prev_timestep):

         return variance

-    def set_timesteps(self, num_inference_steps, offset=0):
+    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
@@ -198,7 +198,12 @@ def step(

         return {"prev_sample": prev_sample}

-    def add_noise(self, original_samples, noise, timesteps):
+    def add_noise(
+        self,
+        original_samples: Union[torch.FloatTensor, np.ndarray],
+        noise: Union[torch.FloatTensor, np.ndarray],
+        timesteps: Union[torch.IntTensor, np.ndarray],
+    ) -> Union[torch.FloatTensor, np.ndarray]:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
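
For reference, below is a minimal sketch of how the newly annotated entry points are called. It assumes DDIMScheduler is exported from the top-level diffusers package at this revision; the tensor shapes, step count, and variable names are illustrative only and not part of the commit.

# Minimal usage sketch (illustrative; shapes and step count are arbitrary).
import torch

from diffusers import DDIMScheduler

# Constructor arguments now carry explicit type hints (see the hunk at line 73).
scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
    tensor_format="pt",  # keep the scheduler's internal buffers as torch tensors
)

# set_timesteps(num_inference_steps: int, offset: int = 0)
scheduler.set_timesteps(50)

# add_noise accepts torch tensors or numpy arrays and returns the same kind.
original_samples = torch.randn(1, 3, 64, 64)  # hypothetical sample batch
noise = torch.randn_like(original_samples)
timesteps = torch.randint(0, 1000, (1,), dtype=torch.long)
noisy_samples = scheduler.add_noise(original_samples, noise, timesteps)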

0 commit comments