|
16 | 16 | # and https://github.com/hojonathanho/diffusion
|
17 | 17 |
|
18 | 18 | import math
|
19 |
| -from typing import Union |
| 19 | +from typing import Optional, Union |
20 | 20 |
|
21 | 21 | import numpy as np
|
22 | 22 | import torch
|
@@ -52,15 +52,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
52 | 52 | @register_to_config
|
53 | 53 | def __init__(
|
54 | 54 | self,
|
55 |
| - num_train_timesteps=1000, |
56 |
| - beta_start=0.0001, |
57 |
| - beta_end=0.02, |
58 |
| - beta_schedule="linear", |
59 |
| - trained_betas=None, |
60 |
| - timestep_values=None, |
61 |
| - clip_sample=True, |
62 |
| - set_alpha_to_one=True, |
63 |
| - tensor_format="pt", |
| 55 | + num_train_timesteps: int = 1000, |
| 56 | + beta_start: float = 0.0001, |
| 57 | + beta_end: float = 0.02, |
| 58 | + beta_schedule: str = "linear", |
| 59 | + trained_betas: Optional[np.ndarray] = None, |
| 60 | + timestep_values: Optional[np.ndarray] = None, |
| 61 | + clip_sample: bool = True, |
| 62 | + set_alpha_to_one: bool = True, |
| 63 | + tensor_format: str = "pt", |
64 | 64 | ):
|
65 | 65 |
|
66 | 66 | if beta_schedule == "linear":
|
@@ -100,7 +100,7 @@ def _get_variance(self, timestep, prev_timestep):
|
100 | 100 |
|
101 | 101 | return variance
|
102 | 102 |
|
103 |
| - def set_timesteps(self, num_inference_steps, offset=0): |
| 103 | + def set_timesteps(self, num_inference_steps: int, offset: int = 0): |
104 | 104 | self.num_inference_steps = num_inference_steps
|
105 | 105 | self.timesteps = np.arange(
|
106 | 106 | 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
|
@@ -176,7 +176,12 @@ def step(
|
176 | 176 |
|
177 | 177 | return {"prev_sample": prev_sample}
|
178 | 178 |
|
179 |
| - def add_noise(self, original_samples, noise, timesteps): |
| 179 | + def add_noise( |
| 180 | + self, |
| 181 | + original_samples: Union[torch.FloatTensor, np.ndarray], |
| 182 | + noise: Union[torch.FloatTensor, np.ndarray], |
| 183 | + timesteps: Union[torch.IntTensor, np.ndarray], |
| 184 | + ) -> Union[torch.FloatTensor, np.ndarray]: |
180 | 185 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
181 | 186 | sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
182 | 187 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
|
0 commit comments