|
16 | 16 | # and https://github.com/hojonathanho/diffusion
|
17 | 17 |
|
18 | 18 | import math
|
19 |
| -from typing import Union |
| 19 | +from typing import Optional, Union |
20 | 20 |
|
21 | 21 | import numpy as np
|
22 | 22 | import torch
|
@@ -73,15 +73,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
73 | 73 | @register_to_config
|
74 | 74 | def __init__(
|
75 | 75 | self,
|
76 |
| - num_train_timesteps=1000, |
77 |
| - beta_start=0.0001, |
78 |
| - beta_end=0.02, |
79 |
| - beta_schedule="linear", |
80 |
| - trained_betas=None, |
81 |
| - timestep_values=None, |
82 |
| - clip_sample=True, |
83 |
| - set_alpha_to_one=True, |
84 |
| - tensor_format="pt", |
| 76 | + num_train_timesteps: int = 1000, |
| 77 | + beta_start: float = 0.0001, |
| 78 | + beta_end: float = 0.02, |
| 79 | + beta_schedule: str = "linear", |
| 80 | + trained_betas: Optional[np.ndarray] = None, |
| 81 | + timestep_values: Optional[np.ndarray] = None, |
| 82 | + clip_sample: bool = True, |
| 83 | + set_alpha_to_one: bool = True, |
| 84 | + tensor_format: str = "pt", |
85 | 85 | ):
|
86 | 86 | if trained_betas is not None:
|
87 | 87 | self.betas = np.asarray(trained_betas)
|
@@ -122,7 +122,7 @@ def _get_variance(self, timestep, prev_timestep):
|
122 | 122 |
|
123 | 123 | return variance
|
124 | 124 |
|
125 |
| - def set_timesteps(self, num_inference_steps, offset=0): |
| 125 | + def set_timesteps(self, num_inference_steps: int, offset: int = 0): |
126 | 126 | self.num_inference_steps = num_inference_steps
|
127 | 127 | self.timesteps = np.arange(
|
128 | 128 | 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
|
@@ -198,7 +198,12 @@ def step(
|
198 | 198 |
|
199 | 199 | return {"prev_sample": prev_sample}
|
200 | 200 |
|
201 |
| - def add_noise(self, original_samples, noise, timesteps): |
| 201 | + def add_noise( |
| 202 | + self, |
| 203 | + original_samples: Union[torch.FloatTensor, np.ndarray], |
| 204 | + noise: Union[torch.FloatTensor, np.ndarray], |
| 205 | + timesteps: Union[torch.IntTensor, np.ndarray], |
| 206 | + ) -> Union[torch.FloatTensor, np.ndarray]: |
202 | 207 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
203 | 208 | sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
204 | 209 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
|
0 commit comments