Let chroma TE work on regular flux. (#8429) · psy-repos-python/ComfyUI@4248b16 · GitHub

Commit 4248b16

Let chroma TE work on regular flux. (comfyanonymous#8429)
1 parent 866f6cd commit 4248b16

File tree

  comfy/ldm/flux/controlnet.py
  comfy/ldm/flux/model.py

2 files changed: +9 −2 lines changed

comfy/ldm/flux/controlnet.py

Lines changed: 4 additions & 1 deletion
@@ -121,6 +121,9 @@ def forward_orig(
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")

+        if y is None:
+            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+
         # running on sequences img
         img = self.img_in(img)

@@ -174,7 +177,7 @@ def forward_orig(
         out["output"] = out_output[:self.main_model_single]
         return out

-    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
+    def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs):
         patch_size = 2
         if self.latent_input:
             hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))

comfy/ldm/flux/model.py

Lines changed: 5 additions & 1 deletion
@@ -101,6 +101,10 @@ def forward_orig(
         transformer_options={},
         attn_mask: Tensor = None,
     ) -> Tensor:
+
+        if y is None:
+            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")

@@ -188,7 +192,7 @@ def block_wrap(args):
         img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
         return img

-    def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
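
Both files receive the same fix: the forward signatures now accept y=None, and when the pooled conditioning vector is missing (as when the Chroma text encoder is used with a regular Flux model) a zero vector of width vec_in_dim is substituted on the input's device and dtype. A minimal standalone sketch of that fallback follows; the helper name ensure_pooled_vector, the example tensor shapes, and vec_in_dim=768 are illustrative assumptions, not ComfyUI's API.

import torch
from typing import Optional

def ensure_pooled_vector(img: torch.Tensor, y: Optional[torch.Tensor], vec_in_dim: int) -> torch.Tensor:
    # Fallback mirroring the pattern in the diff above: if no pooled text-encoder
    # output is provided, use zeros so downstream projections still get valid input.
    if y is None:
        y = torch.zeros((img.shape[0], vec_in_dim), device=img.device, dtype=img.dtype)
    return y

# Usage: a batch of 2 token sequences with no pooled output from the text encoder.
img = torch.randn(2, 1024, 64)            # (batch, tokens, hidden); shape chosen for illustration
y = ensure_pooled_vector(img, None, 768)  # vec_in_dim=768 is an assumed width
print(y.shape)                            # torch.Size([2, 768])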

0 commit comments
