8000 [Type hint] scheduling ddim (#343) · yoonseokjin/diffusers@917b137 · GitHub 10000
[go: up one dir, main page]

Skip to content

Commit 917b137

Browse files
[Type hint] scheduling ddim (huggingface#343)
* [Type hint] scheduling ddim * apply suggestions from code review apply suggestions to also return the return type Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent fdcfb27 commit 917b137

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

schedulers/scheduling_ddim.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# and https://github.com/hojonathanho/diffusion
1717

1818
import math
19-
from typing import Union
19+
from typing import Optional, Union
2020

2121
import numpy as np
2222
import torch
@@ -52,15 +52,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
5252
@register_to_config
5353
def __init__(
5454
self,
55-
num_train_timesteps=1000,
56-
beta_start=0.0001,
57-
beta_end=0.02,
58-
beta_schedule="linear",
59-
trained_betas=None,
60-
timestep_values=None,
61-
clip_sample=True,
62-
set_alpha_to_one=True,
63-
tensor_format="pt",
55+
num_train_timesteps: int = 1000,
56+
beta_start: float = 0.0001,
57+
beta_end: float = 0.02,
58+
beta_schedule: str = "linear",
59+
trained_betas: Optional[np.ndarray] = None,
60+
timestep_values: Optional[np.ndarray] = None,
61+
clip_sample: bool = True,
62+
set_alpha_to_one: bool = True,
63+
tensor_format: str = "pt",
6464
):
6565

6666
if beta_schedule == "linear":
@@ -100,7 +100,7 @@ def _get_variance(self, timestep, prev_timestep):
100100

101101
return variance
102102

103-
def set_timesteps(self, num_inference_steps, offset=0):
103+
def set_timesteps(self, num_inference_steps: int, offset: int = 0):
104104
self.num_inference_steps = num_inference_steps
105105
self.timesteps = np.arange(
106106
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
@@ -176,7 +176,12 @@ def step(
176176

177177
return {"prev_sample": prev_sample}
178178

179-
def add_noise(self, original_samples, noise, timesteps):
179+
def add_noise(
180+
self,
181+
original_samples: Union[torch.FloatTensor, np.ndarray],
182+
noise: Union[torch.FloatTensor, np.ndarray],
183+
timesteps: Union[torch.IntTensor, np.ndarray],
184+
) -> Union[torch.FloatTensor, np.ndarray]:
180185
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
181186
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
182187
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

0 commit comments

Comments
 (0)
0