[docs sprint] schedulers docs, will update (#376) · huggingface/diffusers@e6110f6 · GitHub

Commit e6110f6
Nathan Lambert, daspartho, santiviquez, and patrickvonplaten
authored
[docs sprint] schedulers docs, will update (#376)
* init schedulers docs
* add some docstrings, fix sidebar formatting
* add docstrings
* [Type hint] PNDM schedulers (#335)
* [Type hint] PNDM Schedulers
* ran make style
* updated timesteps type hint
* apply suggestions from code review
* ran make style
* removed unused import
* [Type hint] scheduling ddim (#343)
* [Type hint] scheduling ddim
* apply suggestions from code review

apply suggestions to also return the return type

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make style
* update class docstrings
* add docstrings
* missed merge edit
* add general docs page
* modify headings for right sidebar

Co-authored-by: Partho <parthodas6176@gmail.com>
Co-authored-by: Santiago Víquez <santi.viquez@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent cee3aa0 commit e6110f6

File tree

9 files changed: +470 -63 lines changed

docs/source/api/schedulers.mdx

Lines changed: 86 additions & 10 deletions
@@ -10,19 +10,95 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
 specific language governing permissions and limitations under the License.
 -->
 
-# Models
+# Schedulers
+
+Diffusers contains multiple pre-built schedule functions for the diffusion process.
+
+## What is a scheduler?
+The schedule functions, denoted *Schedulers* in the library, take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep, and return a denoised sample.
+
+- Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs (see the sketch after this list).
+- Adding noise in different manners represents the algorithmic processes used to train a diffusion model by adding noise to images.
+- For inference, the scheduler defines how to update a sample based on an output from a pretrained model.
+- Schedulers are often defined by a *noise schedule* and an *update rule* to solve the differential equation.
+
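To make the contract above concrete, here is a minimal sketch of a single scheduler update (not part of this commit; the random tensors stand in for a real model's prediction, and `prev_sample` is the output field documented under [`SchedulerOutput`] below):

```python
import torch
from diffusers import DDPMScheduler

# Minimal sketch: model output + current sample + timestep -> denoised sample.
scheduler = DDPMScheduler(num_train_timesteps=1000)

sample = torch.randn(1, 3, 32, 32)        # the sample the diffusion process is iterating on
model_output = torch.randn(1, 3, 32, 32)  # stand-in for a trained model's noise prediction

output = scheduler.step(model_output, timestep=999, sample=sample)
sample = output.prev_sample               # the sample, one step less noisy
```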
25+
+### Discrete versus continuous schedulers
+All schedulers take in a timestep to predict the updated version of the sample being diffused.
+The timesteps dictate where in the diffusion process the step is: data is generated by iterating forward in time, and inference is executed by propagating backwards through timesteps.
+Different algorithms use timesteps that are either discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], or continuous (accepting `float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`].
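A hedged illustration of this typing difference (the `set_sigmas` and `step_pred` calls reflect the score-based scheduler's API as I understand it at this commit, and should be treated as assumptions):

```python
import torch
from diffusers import DDPMScheduler, ScoreSdeVeScheduler

model_output = torch.randn(1, 3, 32, 32)
sample = torch.randn(1, 3, 32, 32)

# Discrete scheduler: timesteps are integer indices into the training schedule.
ddpm = DDPMScheduler(num_train_timesteps=1000)
out = ddpm.step(model_output, timestep=980, sample=sample)

# Continuous scheduler: timesteps are floats on the SDE time axis.
sde_ve = ScoreSdeVeScheduler()
sde_ve.set_timesteps(num_inference_steps=50)
sde_ve.set_sigmas(num_inference_steps=50)
out = sde_ve.step_pred(model_output, timestep=0.99, sample=sample)
```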
+
+## Designing Re-usable schedulers
+The core design principle behind the schedule functions is to be model, system, and framework independent.
+This allows for rapid experimentation and cleaner abstractions in the code, where the model prediction is separated from the sample update.
+To this end, the design of schedulers is such that:
+- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality, as sketched below.
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial NumPy support currently exists).
 
-Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
-The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
-The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
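To illustrate the interchangeability described above, a hedged sketch (the checkpoint name is illustrative; the loop itself runs unchanged with either scheduler):

```python
import torch
from diffusers import DDIMScheduler, PNDMScheduler, UNet2DModel

# The same pretrained UNet can be driven by different schedulers; only the
# scheduler line changes when trading speed against sample quality.
model = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
scheduler = DDIMScheduler()  # or: scheduler = PNDMScheduler()
scheduler.set_timesteps(50)

sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = model(sample, t).sample
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```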
 
 
 ## API
+The core API for any new scheduler must follow a limited structure.
+- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
+- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
+- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch,
+  with a `set_format(...)` method. A skeleton of this contract is sketched below.
+
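A hedged skeleton of that contract, not a real diffusers class: the class, output dataclass, and the placeholder update rule are hypothetical, named only to mirror the three methods the text requires.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class MySchedulerOutput:
    prev_sample: np.ndarray


class MyScheduler:
    def __init__(self, num_train_timesteps: int = 1000):
        self.num_train_timesteps = num_train_timesteps
        self.timesteps = None

    def set_timesteps(self, num_inference_steps: int):
        # Configure the schedule for a specific inference task.
        step = self.num_train_timesteps // num_inference_steps
        self.timesteps = np.arange(0, self.num_train_timesteps, step)[::-1]

    def set_format(self, tensor_format: str = "pt"):
        # Convert internal arrays to a specific framework (e.g., PyTorch).
        ...

    def step(self, model_output, timestep: int, sample) -> MySchedulerOutput:
        # Update rule: move `sample` one step back along the diffusion process.
        prev_sample = sample - model_output  # placeholder update, not a real rule
        return MySchedulerOutput(prev_sample=prev_sample)
```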
+### Core
+The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
+
+#### SchedulerMixin
+[[autodoc]] SchedulerMixin
+
+#### SchedulerOutput
+The class [`SchedulerOutput`] contains the outputs from any scheduler's `step(...)` call.
+[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
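For example, a sketch of both ways of consuming a `step(...)` result, per the `return_dict` flag in the signatures added by this commit:

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
sample = torch.randn(1, 3, 32, 32)
model_output = torch.randn(1, 3, 32, 32)

# Default: a SchedulerOutput class with named fields.
output = scheduler.step(model_output, 999, sample)
prev_sample = output.prev_sample

# With return_dict=False: a plain tuple instead of the output class.
(prev_sample,) = scheduler.step(model_output, 999, sample, return_dict=False)
```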
+
+### Existing Schedulers
+
+#### Denoising diffusion implicit models (DDIM)
+
+Original paper can be found [here](https://arxiv.org/abs/2010.02502).
+
+[[autodoc]] schedulers.scheduling_ddim.DDIMScheduler
+
+#### Denoising diffusion probabilistic models (DDPM)
+
+Original paper can be found [here](https://arxiv.org/abs/2006.11239).
+
+[[autodoc]] schedulers.scheduling_ddpm.DDPMScheduler
+
+#### Variance exploding, stochastic sampling from Karras et al.
+
+Original paper can be found [here](https://arxiv.org/abs/2206.00364).
+
+[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeScheduler
+
+#### Linear multistep scheduler for discrete beta schedules
+
+Original implementation can be found [here](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
+
+[[autodoc]] schedulers.scheduling_lms_discrete.LMSDiscreteScheduler
+
+#### Pseudo numerical methods for diffusion models (PNDM)
+
+Original paper can be found [here](https://arxiv.org/abs/2202.09778).
+
+[[autodoc]] schedulers.scheduling_pndm.PNDMScheduler
+
+#### Variance exploding stochastic differential equation (SDE) scheduler
+
+Original paper can be found [here](https://arxiv.org/abs/2011.13456).
+
+[[autodoc]] schedulers.scheduling_sde_ve.ScoreSdeVeScheduler
+
+#### Variance preserving stochastic differential equation (SDE) scheduler
+
+Original paper can be found [here](https://arxiv.org/abs/2011.13456).
+
+<Tip warning={true}>
 
-Models should provide the `def forward` function and initialization of the model.
-All saving, loading, and utilities should be in the base ['ModelMixin'] class.
+
+Score SDE-VP is under construction.
 
-## Examples
+
+</Tip>
 
-- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
-- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
-- TODO: mention VAE / SDE score estimation
+
+[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 61 additions & 6 deletions
@@ -30,11 +30,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
-    :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
-                                    from 0 to 1 and
-                                    produces the cumulative product of (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                      prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
     """
 
     def alpha_bar(time_step):
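For context, a hedged sketch of what this helper computes, following the docstring above (the cosine `alpha_bar` shown here is an assumption, matching the `squaredcos_cap_v2` schedule the function is typically paired with):

```python
import math

import numpy as np


def betas_for_alpha_bar_sketch(num_diffusion_timesteps, max_beta=0.999):
    def alpha_bar(time_step):
        # assumed cumulative-product curve (cosine schedule)
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t keeps alpha_bar on the given curve; cap it to avoid
        # singularities as alpha_bar approaches zero.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float64)
```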
@@ -49,6 +55,29 @@ def alpha_bar(time_step):
 
 
 class DDIMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Denoising diffusion implicit models (DDIM) is a scheduler that extends the denoising procedure introduced in
+    denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+    For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        timestep_values (`np.ndarray`, optional): TODO
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        set_alpha_to_one (`bool`, default `True`):
+            whether the alpha for the final step is 1 or the final alpha of the "non-previous" one.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -62,7 +91,8 @@ def __init__(
         set_alpha_to_one: bool = True,
         tensor_format: str = "pt",
     ):
-
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
         if beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
@@ -101,6 +131,14 @@ def _get_variance(self, timestep, prev_timestep):
         return variance
 
     def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            offset (`int`): TODO
+        """
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
@@ -118,7 +156,24 @@ def step(
         generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
-
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            eta (`float`): weight of noise for added noise in diffusion step.
+            use_clipped_model_output (`bool`): TODO
+            generator: random number generator.
+            return_dict (`bool`): option for returning a tuple rather than a SchedulerOutput class.
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+        """
         if self.num_inference_steps is None:
             raise ValueError(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 59 additions & 5 deletions
@@ -29,11 +29,17 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
 
-    :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
-                                    from 0 to 1 and
-                                    produces the cumulative product of (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
+    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+    to that part of the diffusion process.
+
+    Args:
+        num_diffusion_timesteps (`int`): the number of betas to produce.
+        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                      prevent singularities.
+
+    Returns:
+        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
     """
 
     def alpha_bar(time_step):
@@ -48,6 +54,29 @@ def alpha_bar(time_step):
 
 
 class DDPMScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+    Langevin dynamics sampling.
+
+    For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, optional): TODO
+        variance_type (`str`):
+            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+        clip_sample (`bool`, default `True`):
+            option to clip predicted sample between -1 and 1 for numerical stability.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+    """
+
     @register_to_config
     def __init__(
         self,
@@ -88,6 +117,13 @@ def __init__(
         self.variance_type = variance_type
 
     def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
@@ -137,7 +173,25 @@ def step(
         generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
-
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+            predict_epsilon (`bool`):
+                optional flag to use when the model predicts the samples directly instead of the noise, epsilon.
+            generator: random number generator.
+            return_dict (`bool`): option for returning a tuple rather than a SchedulerOutput class.
+
+        Returns:
+            `SchedulerOutput`: updated sample in the diffusion chain.
+        """
         t = timestep
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
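A hedged sketch of the DDPM `step(...)` documented above (random tensors stand in for a real UNet's noise prediction): a `generator` makes the stochastic update reproducible, and `predict_epsilon=False` is for models that predict the sample directly rather than the noise.

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, variance_type="fixed_small")

sample = torch.randn(1, 3, 32, 32)
model_output = torch.randn(1, 3, 32, 32)
generator = torch.manual_seed(0)

# Default: model_output is interpreted as predicted noise (epsilon).
out = scheduler.step(model_output, timestep=999, sample=sample, generator=generator)
prev_sample = out.prev_sample

# For sample-predicting models, flip the flag.
out = scheduler.step(model_output, 999, sample, predict_epsilon=False, generator=generator)
```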
